Skip to content

Commit

Permalink
Merge pull request #113 from sine-fdn/v0.2
Browse files Browse the repository at this point in the history
Hide low-level type details when using `compile` + `GarbleProgram`
  • Loading branch information
fkettelhoit committed Jan 23, 2024
2 parents 223243d + 364e88a commit 449ef65
Show file tree
Hide file tree
Showing 9 changed files with 493 additions and 297 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "garble_lang"
version = "0.1.8"
version = "0.2.0"
edition = "2021"
rust-version = "1.60.0"
description = "Turing-Incomplete Programming Language for Multi-Party Computation with Garbled Circuits"
Expand Down
107 changes: 105 additions & 2 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,25 @@
use crate::{compile::wires_as_unsigned, env::Env, token::MetaInfo};
use std::collections::HashMap;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

// This module currently implements a few basic kinds of circuit optimizations:
//
// 1. Constant evaluation (e.g. x ^ 0 == x; x & 1 == x; x & 0 == 0)
// 2. Sub-expression sharing (wires are re-used if a gate with the same type and inputs exists)
// 3. Pruning of useless gates (gates that are not part of the output nor used by other gates)

const PRINT_OPTIMIZATION_RATIO: bool = false;
const MAX_GATES: usize = (u32::MAX >> 4) as usize;
const MAX_AND_GATES: usize = (u32::MAX >> 8) as usize;

/// Data type to uniquely identify gates.
pub type GateIndex = usize;

/// Description of a gate executed under S-MPC.
/// Description of a gate executed under MPC.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Gate {
/// A logical XOR gate attached to the two specified input wires.
Xor(GateIndex, GateIndex),
Expand All @@ -25,7 +31,7 @@ pub enum Gate {
Not(GateIndex),
}

/// Representation of a circuit evaluated by an S-MPC engine.
/// Representation of a circuit evaluated by an MPC engine.
///
/// Each circuit consists of 3 parts:
///
Expand Down Expand Up @@ -66,6 +72,7 @@ pub enum Gate {
/// true and constant false, specified as `Gate::Xor(0, 0)` with wire `n` and `Gate::Not(n)` (and
/// thus depend on the first input bit for their specifications).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Circuit {
/// The different parties, with `usize` at index `i` as the number of input bits for party `i`.
pub input_gates: Vec<usize>,
Expand All @@ -75,7 +82,103 @@ pub struct Circuit {
pub output_gates: Vec<GateIndex>,
}

/// An input wire or a gate operating on them.
pub enum Wire {
/// An input wire, with its value coming directly from one of the parties.
Input(GateIndex),
/// A logical XOR gate attached to the two specified input wires.
Xor(GateIndex, GateIndex),
/// A logical AND gate attached to the two specified input wires.
And(GateIndex, GateIndex),
/// A logical NOT gate attached to the specified input wire.
Not(GateIndex),
}

/// Errors occurring during the validation or the execution of the MPC protocol.
#[derive(Debug, PartialEq, Eq)]
pub enum CircuitError {
/// The gate with the specified wire contains invalid gate connections.
InvalidGate(usize),
/// The specified output gate does not exist in the circuit.
InvalidOutput(usize),
/// The circuit does not specify any output gates.
EmptyOutputs,
/// The provided circuit has too many gates to be processed.
MaxCircuitSizeExceeded,
/// The provided index does not correspond to any party.
PartyIndexOutOfBounds,
}

impl Circuit {
/// Returns all the wires (inputs + gates) in the circuit, in ascending order.
pub fn wires(&self) -> Vec<Wire> {
let mut gates = vec![];
for (party, inputs) in self.input_gates.iter().enumerate() {
for _ in 0..*inputs {
gates.push(Wire::Input(party))
}
}
for gate in self.gates.iter() {
let gate = match gate {
Gate::Xor(x, y) => Wire::Xor(*x, *y),
Gate::And(x, y) => Wire::And(*x, *y),
Gate::Not(x) => Wire::Not(*x),
};
gates.push(gate);
}
gates
}

/// Returns the number of AND gates in the circuit.
pub fn and_gates(&self) -> usize {
self.gates
.iter()
.filter(|g| matches!(g, Gate::And(_, _)))
.count()
}

/// Checks that the circuit only uses valid wires, includes no cycles, has outputs, etc.
pub fn validate(&self) -> Result<(), CircuitError> {
let mut num_and_gates = 0;
let wires = self.wires();
for (i, g) in wires.iter().enumerate() {
match g {
Wire::Input(_) => {}
&Wire::Xor(x, y) => {
if x >= i || y >= i {
return Err(CircuitError::InvalidGate(i));
}
}
&Wire::And(x, y) => {
if x >= i || y >= i {
return Err(CircuitError::InvalidGate(i));
}
num_and_gates += 1;
}
&Wire::Not(x) => {
if x >= i {
return Err(CircuitError::InvalidGate(i));
}
}
}
}
if self.output_gates.is_empty() {
return Err(CircuitError::EmptyOutputs);
}
for &o in self.output_gates.iter() {
if o >= wires.len() {
return Err(CircuitError::InvalidOutput(o));
}
}
if num_and_gates > MAX_AND_GATES {
return Err(CircuitError::MaxCircuitSizeExceeded);
}
if wires.len() > MAX_GATES {
return Err(CircuitError::MaxCircuitSizeExceeded);
}
Ok(())
}

/// Evaluates the circuit with the specified inputs (with one `Vec<bool>` per party).
///
/// Assumes that the inputs have been previously type-checked and **panics** if the number of
Expand Down
5 changes: 5 additions & 0 deletions src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ pub enum EvalError {
UnexpectedNumberOfInputsFromParty(usize),
/// An input literal could not be parsed.
LiteralParseError(CompileTimeError),
/// The circuit does not have an input argument with the given index.
InvalidArgIndex(usize),
/// The literal is not of the expected parameter type.
InvalidLiteralType(Literal, Type),
/// The number of output bits does not match the expected type.
Expand All @@ -68,6 +70,9 @@ impl std::fmt::Display for EvalError {
EvalError::LiteralParseError(err) => {
err.fmt(f)
}
EvalError::InvalidArgIndex(i) => {
f.write_fmt(format_args!("The circuit does not an input argument with index {i}"))
}
EvalError::InvalidLiteralType(literal, ty) => {
f.write_fmt(format_args!("The argument literal is not of type {ty}: '{literal}'"))
}
Expand Down
136 changes: 127 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,59 @@
//! A purely functional programming language with a Rust-like syntax that compiles to logic gates
//! for secure multi-party computation.
//!
//! Garble programs always terminate and are compiled into a combination of boolean AND / XOR / NOT
//! gates. These boolean circuits can either be executed directly (mostly for testing purposes) or
//! passed to a multi-party computation engine.
//!
//! ```rust
//! use garble_lang::{compile, literal::Literal, token::UnsignedNumType::U32};
//!
//! // Compile and type-check a simple program to add the inputs of 3 parties:
//! let code = "pub fn main(x: u32, y: u32, z: u32) -> u32 { x + y + z }";
//! let prg = compile(code).map_err(|e| e.prettify(&code)).unwrap();
//!
//! // We can evaluate the circuit directly, useful for testing purposes:
//! let mut eval = prg.evaluator();
//! eval.set_u32(2);
//! eval.set_u32(10);
//! eval.set_u32(100);
//! let output = eval.run().map_err(|e| e.prettify(&code)).unwrap();
//! assert_eq!(u32::try_from(output).map_err(|e| e.prettify(&code)).unwrap(), 2 + 10 + 100);
//!
//! // Or we can run the compiled circuit in an MPC engine, simulated using `prg.circuit.eval()`:
//! let x = prg.parse_arg(0, "2u32").unwrap().as_bits();
//! let y = prg.parse_arg(1, "10u32").unwrap().as_bits();
//! let z = prg.parse_arg(2, "100u32").unwrap().as_bits();
//! let output = prg.circuit.eval(&[x, y, z]); // use your own MPC engine here instead
//! let result = prg.parse_output(&output).unwrap();
//! assert_eq!("112u32", result.to_string());
//!
//! // Input arguments can also be constructed directly as literals:
//! let x = prg.literal_arg(0, Literal::NumUnsigned(2, U32)).unwrap().as_bits();
//! let y = prg.literal_arg(1, Literal::NumUnsigned(10, U32)).unwrap().as_bits();
//! let z = prg.literal_arg(2, Literal::NumUnsigned(100, U32)).unwrap().as_bits();
//! let output = prg.circuit.eval(&[x, y, z]); // use your own MPC engine here instead
//! let result = prg.parse_output(&output).unwrap();
//! assert_eq!(Literal::NumUnsigned(112, U32), result);
//! ```

#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(rustdoc::broken_intra_doc_links)]

use ast::{Expr, FnDef, Pattern, Program, Stmt, Type, VariantExpr};
use check::TypeError;
use circuit::Circuit;
use compile::CompilerError;
use eval::EvalError;
use eval::{EvalError, Evaluator};
use literal::Literal;
use parse::ParseError;
use scan::{scan, ScanError};
use std::fmt::Write as _;
use std::fmt::{Display, Write as _};
use token::MetaInfo;

use ast::{Expr, FnDef, Pattern, Program, Stmt, Type, VariantExpr};
use circuit::Circuit;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/// [`crate::ast::Program`] without any associated type information.
pub type UntypedProgram = Program<()>;
Expand Down Expand Up @@ -56,12 +95,91 @@ pub fn check(prg: &str) -> Result<TypedProgram, Error> {
Ok(scan(prg)?.parse()?.type_check()?)
}

/// Scans, parses, type-checks and then compiles a program to a circuit of gates.
pub fn compile(prg: &str, fn_name: &str) -> Result<(TypedProgram, TypedFnDef, Circuit), Error> {
/// Scans, parses, type-checks and then compiles the `"main"` fn of a program to a boolean circuit.
pub fn compile(prg: &str) -> Result<GarbleProgram, Error> {
let program = check(prg)?;
let (circuit, main_fn) = program.compile(fn_name)?;
let main_fn = main_fn.clone();
Ok((program, main_fn, circuit))
let (circuit, main) = program.compile("main")?;
let main = main.clone();
Ok(GarbleProgram {
program,
main,
circuit,
})
}

/// The result of type-checking and compiling a Garble program.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct GarbleProgram {
/// The type-checked represenation of the full program.
pub program: TypedProgram,
/// The function to be executed as a circuit.
pub main: TypedFnDef,
/// The compilation output, as a circuit of boolean gates.
pub circuit: Circuit,
}

/// An input argument for a Garble program and circuit.
#[derive(Debug, Clone)]
pub struct GarbleArgument<'a>(Literal, &'a TypedProgram);

impl GarbleProgram {
/// Returns an evaluator that can be used to run the compiled circuit.
pub fn evaluator(&self) -> Evaluator<'_> {
Evaluator::new(&self.program, &self.main, &self.circuit)
}

/// Type-checks and uses the literal as the circuit input argument with the given index.
pub fn literal_arg(
&self,
arg_index: usize,
literal: Literal,
) -> Result<GarbleArgument<'_>, EvalError> {
let Some(param) = self.main.params.get(arg_index) else {
return Err(EvalError::InvalidArgIndex(arg_index));
};
if !literal.is_of_type(&self.program, &param.ty) {
return Err(EvalError::InvalidLiteralType(literal, param.ty.clone()));
}
Ok(GarbleArgument(literal, &self.program))
}

/// Tries to parse the string as the circuit input argument with the given index.
pub fn parse_arg(
&self,
arg_index: usize,
literal: &str,
) -> Result<GarbleArgument<'_>, EvalError> {
let Some(param) = self.main.params.get(arg_index) else {
return Err(EvalError::InvalidArgIndex(arg_index));
};
let literal = Literal::parse(&self.program, &param.ty, literal)
.map_err(EvalError::LiteralParseError)?;
Ok(GarbleArgument(literal, &self.program))
}

/// Tries to convert the circuit output back to a Garble literal.
pub fn parse_output(&self, bits: &[bool]) -> Result<Literal, EvalError> {
Literal::from_result_bits(&self.program, &self.main.ty, bits)
}
}

impl GarbleArgument<'_> {
/// Converts the argument to input bits for the compiled circuit.
pub fn as_bits(&self) -> Vec<bool> {
self.0.as_bits(self.1)
}

/// Converts the argument to a Garble literal.
pub fn as_literal(&self) -> Literal {
self.0.clone()
}
}

impl Display for GarbleArgument<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}

/// Errors that can occur during compile time, while a program is scanned, parsed or type-checked.
Expand Down
Loading

0 comments on commit 449ef65

Please sign in to comment.