Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Permutation instructions/links #1126

Merged
merged 1 commit into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions airgen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ repository = { workspace = true }
[dependencies]
powdr-ast = { path = "../ast" }
powdr-number = { path = "../number" }
powdr-analysis = { path = "../analysis" }

log = "0.4.17"
72 changes: 68 additions & 4 deletions airgen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use powdr_ast::{
},
};

use powdr_analysis::utils::parse_pil_statement;

const MAIN_MACHINE: &str = "::Main";
const MAIN_FUNCTION: &str = "main";

Expand Down Expand Up @@ -56,15 +58,47 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph {
instances.push((location, ty));
}

// count incoming permutations for each machine.
let mut incoming_permutations = instances
.iter()
.map(|(location, _)| (location.clone(), 0))
.collect();

// visit the tree compiling the machines
let objects = instances
let mut objects: BTreeMap<_, _> = instances
.into_iter()
.map(|(location, ty)| {
let object = ASMPILConverter::convert_machine(&location, &ty, &input);
let object = ASMPILConverter::convert_machine(
&location,
&ty,
&input,
&mut incoming_permutations,
);
(location, object)
})
.collect();

// add pil code for the selector array and related constraints
for (location, count) in incoming_permutations {
let obj = objects.get_mut(&location).unwrap();
if obj.has_pc {
// VMs don't have call_selectors
continue;
}
assert!(
count == 0 || obj.call_selectors.is_some(),
"block machine {location} has incoming permutations but doesn't declare call_selectors"
);
if let Some(call_selectors) = obj.call_selectors.as_deref() {
obj.pil.extend([
parse_pil_statement(&format!("col witness {call_selectors}[{count}];")),
parse_pil_statement(&format!(
"std::array::map({call_selectors}, std::utils::force_bool);"
)),
]);
}
}

let Item::Machine(main_ty) = input.items.get(&main_ty).unwrap() else {
panic!()
};
Expand All @@ -73,6 +107,7 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph {
location: main_location,
latch: main_ty.latch.clone(),
operation_id: main_ty.operation_id.clone(),
call_selectors: main_ty.call_selectors.clone(),
};
let entry_points = main_ty
.operations()
Expand Down Expand Up @@ -109,15 +144,22 @@ struct ASMPILConverter<'a> {
items: &'a BTreeMap<AbsoluteSymbolPath, Item>,
pil: Vec<PilStatement>,
submachines: Vec<SubmachineDeclaration>,
/// keeps track of the total count of incoming permutations for a given machine.
incoming_permutations: &'a mut BTreeMap<Location, u64>,
}

impl<'a> ASMPILConverter<'a> {
fn new(location: &'a Location, input: &'a AnalysisASMFile) -> Self {
fn new(
location: &'a Location,
input: &'a AnalysisASMFile,
incoming_permutations: &'a mut BTreeMap<Location, u64>,
) -> Self {
Self {
location,
items: &input.items,
pil: Default::default(),
submachines: Default::default(),
incoming_permutations,
}
}

Expand All @@ -129,8 +171,9 @@ impl<'a> ASMPILConverter<'a> {
location: &'a Location,
ty: &'a AbsoluteSymbolPath,
input: &'a AnalysisASMFile,
incoming_permutations: &'a mut BTreeMap<Location, u64>,
) -> Object {
Self::new(location, input).convert_machine_inner(ty)
Self::new(location, input, incoming_permutations).convert_machine_inner(ty)
}

fn convert_machine_inner(mut self, ty: &AbsoluteSymbolPath) -> Object {
Expand All @@ -152,6 +195,8 @@ impl<'a> ASMPILConverter<'a> {
self.handle_pil_statement(block);
}

let call_selectors = input.call_selectors;
let has_pc = input.pc.is_some();
let links = input
.links
.into_iter()
Expand All @@ -162,6 +207,9 @@ impl<'a> ASMPILConverter<'a> {
degree,
pil: self.pil,
links,
latch: input.latch,
call_selectors,
has_pc,
}
}

Expand All @@ -176,6 +224,7 @@ impl<'a> ASMPILConverter<'a> {
callable,
params,
},
is_permutation,
}: LinkDefinitionStatement,
) -> Link {
let from = LinkFrom {
Expand All @@ -198,6 +247,18 @@ impl<'a> ASMPILConverter<'a> {
// get the instance location from the current location joined with the instance name
let instance_location = self.location.clone().join(instance);

let mut selector_idx = None;

if is_permutation {
// increase the permutation count into the destination machine
let count = self
.incoming_permutations
.get_mut(&instance_location)
.unwrap();
selector_idx = Some(*count);
*count += 1;
}

Link {
from,
to: instance_ty
Expand All @@ -207,16 +268,19 @@ impl<'a> ASMPILConverter<'a> {
machine: powdr_ast::object::Machine {
location: instance_location,
latch: instance_ty.latch.clone(),
call_selectors: instance_ty.call_selectors.clone(),
operation_id: instance_ty.operation_id.clone(),
},
operation: Operation {
name: d.name.to_string(),
id: d.operation.id.id.clone(),
params: d.operation.params.clone(),
},
selector_idx,
})
.unwrap()
.clone(),
is_permutation,
}
}
}
54 changes: 50 additions & 4 deletions analysis/src/machine_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ impl TypeChecker {
let mut errors = vec![];

let mut degree = None;
let mut call_selectors = None;
let mut registers = vec![];
let mut pil = vec![];
let mut instructions = vec![];
Expand All @@ -56,6 +57,15 @@ impl TypeChecker {
degree: degree_value,
});
}
MachineStatement::CallSelectors(_, sel) => {
if let Some(other_sel) = &call_selectors {
errors.push(format!(
"Machine {ctx} already has call_selectors ({other_sel})"
));
} else {
call_selectors = Some(sel);
}
}
MachineStatement::RegisterDeclaration(source, name, flag) => {
let ty = match flag {
Some(RegisterFlag::IsAssignment) => RegisterTy::Assignment,
Expand All @@ -75,8 +85,20 @@ impl TypeChecker {
Err(e) => errors.extend(e),
}
}
MachineStatement::LinkDeclaration(source, LinkDeclaration { flag, to }) => {
links.push(LinkDefinitionStatement { source, flag, to });
MachineStatement::LinkDeclaration(
source,
LinkDeclaration {
flag,
to,
is_permutation,
},
) => {
links.push(LinkDefinitionStatement {
source,
flag,
to,
is_permutation,
});
}
MachineStatement::Pil(_source, statement) => {
pil.push(statement);
Expand Down Expand Up @@ -232,9 +254,15 @@ impl TypeChecker {
ctx
));
}
if call_selectors.is_some() {
errors.push(format!(
"Machine {} should not have call_selectors as it has a pc",
ctx
));
}
for l in &links {
errors.push(format!(
"Machine {} should not have links as it has a pc, found `{}`. Use an external instruction instead.",
"Machine {} should not have links as it has a pc, found `{}`. Use an external instruction instead",
ctx, l.flag
));
}
Expand All @@ -254,6 +282,7 @@ impl TypeChecker {
degree,
latch,
operation_id,
call_selectors,
pc: registers
.iter()
.enumerate()
Expand Down Expand Up @@ -436,7 +465,7 @@ machine Main {
expect_check_str(
src,
Err(vec![
"Machine ::Main should not have links as it has a pc, found `foo`. Use an external instruction instead.",
"Machine ::Main should not have links as it has a pc, found `foo`. Use an external instruction instead",
]),
);
}
Expand Down Expand Up @@ -473,4 +502,21 @@ machine Arith(latch, _) {
"#;
expect_check_str(src, Err(vec!["Operation `add` in machine ::Arith can't have an operation id because the machine does not have an operation id column"]));
}

#[test]
fn virtual_machine_has_no_call_selectors() {
let src = r#"
machine Main {
reg pc[@pc];

call_selectors sel;
}
"#;
expect_check_str(
src,
Err(vec![
"Machine ::Main should not have call_selectors as it has a pc",
]),
);
}
}
13 changes: 12 additions & 1 deletion asm-to-pil/src/vm_to_constrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ impl<T: FieldElement> ASMPILConverter<T> {
&params,
body,
),
InstructionBody::CallableRef(callable) => {
InstructionBody::CallablePlookup(callable) => {
let link = self.handle_external_instruction_def(
s.source,
instruction_flag,
Expand All @@ -340,6 +340,16 @@ impl<T: FieldElement> ASMPILConverter<T> {
);
input.links.push(link);
}
InstructionBody::CallablePermutation(callable) => {
let mut link = self.handle_external_instruction_def(
s.source,
instruction_flag,
&params,
callable,
);
link.is_permutation = true;
input.links.push(link);
}
}

let inputs: Vec<_> = params
Expand Down Expand Up @@ -573,6 +583,7 @@ impl<T: FieldElement> ASMPILConverter<T> {
source,
flag: direct_reference(flag),
to: callable,
is_permutation: false,
}
}

Expand Down
8 changes: 7 additions & 1 deletion ast/src/asm_analysis/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,13 @@ impl Display for Machine {

impl Display for LinkDefinitionStatement {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(f, "link {} => {};", self.flag, self.to)
write!(
f,
"link {} {} {};",
self.flag,
if self.is_permutation { "~>" } else { "=>" },
self.to
)
}
}

Expand Down
4 changes: 4 additions & 0 deletions ast/src/asm_analysis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ pub struct LinkDefinitionStatement {
pub flag: Expression,
/// the callable to invoke when the flag is on. TODO: check this during type checking
pub to: CallableRef,
/// true if this is a permutation link
pub is_permutation: bool,
}

#[derive(Clone, Debug, Default)]
Expand Down Expand Up @@ -691,6 +693,8 @@ pub struct Machine {
pub latch: Option<String>,
/// The operation id, i.e. the column whose values determine which operation is being invoked in the current block. Must be defined in one of the constraint blocks of this machine.
pub operation_id: Option<String>,
/// call selector array
pub call_selectors: Option<String>,
/// The set of registers for this machine
pub registers: Vec<RegisterDeclarationStatement>,
/// The index of the program counter in the registers, if any
Expand Down
12 changes: 12 additions & 0 deletions ast/src/object/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ pub struct Object {
pub pil: Vec<PilStatement>,
/// the links from this machine to its children
pub links: Vec<Link>,
/// name of the latch column
pub latch: Option<String>,
/// call selector array
pub call_selectors: Option<String>,
/// true if this machine has a PC
pub has_pc: bool,
}

impl Object {
Expand All @@ -64,6 +70,8 @@ pub struct Link {
pub from: LinkFrom,
/// the link target, i.e. a callable in some machine
pub to: LinkTo,
/// true if this is a permutation link
pub is_permutation: bool,
}

#[derive(Clone)]
Expand All @@ -78,6 +86,8 @@ pub struct LinkTo {
pub machine: Machine,
/// the operation we link to
pub operation: Operation,
/// index into the permutation selector (None if lookup)
pub selector_idx: Option<u64>,
}

#[derive(Clone)]
Expand All @@ -86,6 +96,8 @@ pub struct Machine {
pub location: Location,
/// its latch
pub latch: Option<String>,
/// call selector array
pub call_selectors: Option<String>,
/// its operation id
pub operation_id: Option<String>,
}
Expand Down
Loading
Loading