Skip to content

Commit

Permalink
Add leak checks to poly. VIR and generate them also in loops
Browse files Browse the repository at this point in the history
  • Loading branch information
vfukala committed Jun 19, 2023
1 parent c4eda38 commit 1b9c33f
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 30 deletions.
33 changes: 18 additions & 15 deletions prusti-viper/src/encoder/definition_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,24 @@ impl<'p, 'v: 'p, 'tcx: 'v> Collector<'p, 'v, 'tcx> {
let leak_checked_methods = methods
.into_iter()
.map(|mut method| -> SpannedEncodingResult<vir::CfgMethod> {
let ret_index = method.basic_blocks_labels
.iter()
.position(|label| label == "return");
if ret_index.is_none() {
return Err(SpannedEncodingError::internal("encoded method does not contain a `return` label; cannot add leak checks", self.error_span));
}
let ret_index = method.block_index(ret_index.unwrap());
for identifier in &self.used_obligations {
let leak_check = self.encoder.get_obligation_leak_check(identifier)?;
let leak_check = (*leak_check).clone();
method.add_stmt(
ret_index,
leak_check
);
}
method = method.patch_statements(|stmt| -> SpannedEncodingResult::<_> {
match stmt {
vir::Stmt::LeakCheck(vir::LeakCheck { scope_id }) => {
let mut check_body = vir::Expr::Const(vir::ConstExpr { value: vir::Const::Bool(true), position: vir::Position::default() });
for identifier in &self.used_obligations {
let current_check = self.encoder.get_obligation_leak_check(identifier, scope_id)?;
check_body = vir::Expr::BinOp(vir::BinOp {
op_kind: vir::BinaryOpKind::And,
left: Box::new(check_body),
right: Box::new(current_check),
position: vir::Position::default(),
})
}
Ok(vir::Stmt::Assert(vir::Assert { expr: check_body, position: vir::Position::default() }))
},
_ => { Ok(stmt) }
}
}).unwrap();
Ok(method)
}).collect::<SpannedEncodingResult<Vec<_>>>()?;
Ok(vir::Program {
Expand Down
55 changes: 41 additions & 14 deletions prusti-viper/src/encoder/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub struct Encoder<'v, 'tcx: 'v> {
/// A map containing all functions: identifier → function definition.
functions: RefCell<FxHashMap<vir::FunctionIdentifier, Rc<vir::Function>>>,
obligations: RefCell<FxHashMap<vir::FunctionIdentifier, Rc<vir::Predicate>>>,
obligation_checks: RefCell<FxHashMap<vir::FunctionIdentifier, Rc<vir::Stmt>>>,
obligation_checks: RefCell<FxHashMap<vir::FunctionIdentifier, Rc<vir::ForPerm>>>,
builtin_domains: RefCell<FxHashMap<BuiltinDomainKind, vir::Domain>>,
builtin_domains_in_progress: RefCell<FxHashSet<BuiltinDomainKind>>,
builtin_methods: RefCell<FxHashMap<BuiltinMethodKind, vir::BodylessMethod>>,
Expand Down Expand Up @@ -350,7 +350,7 @@ impl<'v, 'tcx> Encoder<'v, 'tcx> {
let mut args = vec![vir::LocalVar::new("scope_id", vir::Type::Int)];
let mut check_args = vec![];
let mut concrete_args = vec![vir::Expr::Const(vir::ConstExpr {
value: vir::Const::Int(-1),
value: vir::Const::Int(-2),
position: vir::Position::default(),
})];
for local_idx in 1..sig.skip_binder().inputs().len() {
Expand All @@ -374,18 +374,15 @@ impl<'v, 'tcx> Encoder<'v, 'tcx> {
args: concrete_args,
formal_arguments: args.clone(),
};
let check = vir::Stmt::Assert(vir::Assert {
expr: vir::Expr::ForPerm(vir::ForPerm {
variables: check_args,
access: obligation_access,
body: Box::new(vir::Expr::Const(vir::ConstExpr {
value: vir::Const::Bool(false),
position: vir::Position::default(),
})),
let check = vir::ForPerm {
variables: check_args,
access: obligation_access,
body: Box::new(vir::Expr::Const(vir::ConstExpr {
value: vir::Const::Bool(false),
position: vir::Position::default(),
}),
})),
position: vir::Position::default(),
});
};
self.obligation_checks.borrow_mut().insert(ident.clone()/*obligation_name.clone().into()*/, Rc::new(check));
}
Ok(())
Expand All @@ -401,11 +398,41 @@ impl<'v, 'tcx> Encoder<'v, 'tcx> {
}
}

pub(super) fn get_obligation_leak_check(&self, identifier: &vir::FunctionIdentifier) -> SpannedEncodingResult<Rc<vir::Stmt>> {
pub(super) fn get_obligation_leak_check(&self, identifier: &vir::FunctionIdentifier, scope_id: isize) -> SpannedEncodingResult<vir::Expr> {
self.ensure_obligation_encoded(identifier)?;
if self.obligation_checks.borrow().contains_key(identifier) {
let map = self.obligation_checks.borrow();
Ok(map[identifier].clone())
let check = map[identifier].clone();
Ok(match (*check).clone() {
vir::ForPerm {
variables,
access: vir::ObligationAccess {
name,
args,
formal_arguments,
},
body,
position,
} => vir::Expr::ForPerm(vir::ForPerm {
variables,
access: vir::ObligationAccess {
name,
args: args.into_iter().enumerate().map(|(i, a)| {
if i == 0 {
vir::Expr::Const(vir::ConstExpr {
value: vir::Const::Int(scope_id.try_into().unwrap()),
position: vir::Position::default(),
})
} else {
a
}
}).collect(),
formal_arguments,
},
body,
position,
})
})
} else {
unreachable!("Not found obligation check: {:?}", identifier);
}
Expand Down
4 changes: 4 additions & 0 deletions prusti-viper/src/encoder/foldunfold/requirements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ impl RequiredStmtPermissionsGetter for vir::Stmt {
base.get_required_stmt_permissions(predicates)
}

&vir::Stmt::LeakCheck(..) => {
FxHashSet::default()
}

ref x => unimplemented!("{}", x),
}
}
Expand Down
2 changes: 2 additions & 0 deletions prusti-viper/src/encoder/foldunfold/semantics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,8 @@ impl ApplyOnState for vir::Stmt {
}
}

&vir::Stmt::LeakCheck(..) => {}

ref x => unimplemented!("{}", x),
}
Ok(())
Expand Down
3 changes: 3 additions & 0 deletions prusti-viper/src/encoder/procedure_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> {
self.cfg_method.add_stmts(inv_post_block_perms, stmts);
}

// TODO: also check for obligations in postconditions here
let mid_groups = if preconds.is_empty() {
// Encode the mid G group (start - G - B1 - invariant_perm - *G* - B1 - invariant_fnspec - B2 - G - B1 - end)
let mid_g = self.encode_blocks_group(
Expand Down Expand Up @@ -1173,6 +1174,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> {
&scope_ids,
)?;
self.cfg_method.add_stmts(end_body_block, stmts);
self.cfg_method.add_stmt(end_body_block, vir::Stmt::LeakCheck(vir::LeakCheck { scope_id: loop_head.index() as isize }));
}
self.cfg_method.add_stmt(
end_body_block,
Expand Down Expand Up @@ -5178,6 +5180,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> {
position: func_pos,
}),
);
self.cfg_method.add_stmt(return_cfg_block, vir::Stmt::LeakCheck(vir::LeakCheck { scope_id: -1 }));

// Assert type invariants
self.cfg_method.add_stmt(
Expand Down
31 changes: 31 additions & 0 deletions vir/defs/polymorphic/ast/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ pub enum Stmt {
/// * place to the enumeration instance
/// * field that encodes the variant
Downcast(Downcast),
LeakCheck(LeakCheck),
}

impl fmt::Display for Stmt {
Expand All @@ -88,6 +89,7 @@ impl fmt::Display for Stmt {
Stmt::ExpireBorrows(expire_borrows) => expire_borrows.fmt(f),
Stmt::If(if_stmt) => if_stmt.fmt(f),
Stmt::Downcast(downcast) => downcast.fmt(f),
Stmt::LeakCheck(leak_check) => leak_check.fmt(f),
}
}
}
Expand Down Expand Up @@ -450,6 +452,17 @@ impl fmt::Display for Downcast {
}
}

#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct LeakCheck {
pub scope_id: isize,
}

impl fmt::Display for LeakCheck {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "leak for scope_id = {}", self.scope_id)
}
}

impl Stmt {
pub fn is_comment(&self) -> bool {
matches!(self, Stmt::Comment(_))
Expand Down Expand Up @@ -576,6 +589,7 @@ pub trait StmtFolder {
Stmt::ExpireBorrows(expire_borrows) => self.fold_expire_borrows(expire_borrows),
Stmt::If(if_stmt) => self.fold_if(if_stmt),
Stmt::Downcast(downcast) => self.fold_downcast(downcast),
Stmt::LeakCheck(leak_check) => self.fold_leak_check(leak_check),
}
}

Expand Down Expand Up @@ -761,6 +775,10 @@ pub trait StmtFolder {
field,
})
}

fn fold_leak_check(&mut self, statement: LeakCheck) -> Stmt {
Stmt::LeakCheck(statement)
}
}

pub trait FallibleStmtFolder {
Expand Down Expand Up @@ -793,6 +811,7 @@ pub trait FallibleStmtFolder {
}
Stmt::If(if_stmt) => self.fallible_fold_if(if_stmt),
Stmt::Downcast(downcast) => self.fallible_fold_downcast(downcast),
Stmt::LeakCheck(leak_check) => self.fallible_fold_leak_check(leak_check),
}
}

Expand Down Expand Up @@ -1009,6 +1028,10 @@ pub trait FallibleStmtFolder {
field,
}))
}

fn fallible_fold_leak_check(&mut self, statement: LeakCheck) -> Result<Stmt, Self::Error> {
Ok(Stmt::LeakCheck(statement))
}
}

pub trait StmtWalker {
Expand All @@ -1035,6 +1058,7 @@ pub trait StmtWalker {
Stmt::ExpireBorrows(expire_borrows) => self.walk_expire_borrows(expire_borrows),
Stmt::If(if_stmt) => self.walk_if(if_stmt),
Stmt::Downcast(downcast) => self.walk_downcast(downcast),
Stmt::LeakCheck(leak_check) => self.walk_leak_check(leak_check),
}
}

Expand Down Expand Up @@ -1163,6 +1187,8 @@ pub trait StmtWalker {
let Downcast { base, .. } = statement;
self.walk_expr(base);
}

fn walk_leak_check(&mut self, _statement: &LeakCheck) {}
}

pub trait FallibleStmtWalker {
Expand Down Expand Up @@ -1195,6 +1221,7 @@ pub trait FallibleStmtWalker {
}
Stmt::If(if_stmt) => self.fallible_walk_if(if_stmt),
Stmt::Downcast(downcast) => self.fallible_walk_downcast(downcast),
Stmt::LeakCheck(leak_check) => self.fallible_walk_leak_check(leak_check),
}
}

Expand Down Expand Up @@ -1369,6 +1396,10 @@ pub trait FallibleStmtWalker {
self.fallible_walk_expr(base)?;
Ok(())
}

fn fallible_walk_leak_check(&mut self, _statement: &LeakCheck) -> Result<(), Self::Error> {
Ok(())
}
}

pub fn stmts_to_str(stmts: &[Stmt]) -> String {
Expand Down
5 changes: 4 additions & 1 deletion vir/src/converter/polymorphic_to_legacy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,10 @@ impl From<polymorphic::Stmt> for legacy::Stmt {
),
polymorphic::Stmt::Downcast(downcast) => {
legacy::Stmt::Downcast(downcast.base.into(), downcast.field.into())
}
},
polymorphic::Stmt::LeakCheck(_) => {
panic!("all leak check markers needs to removed before convering polymorphic VIR to legacy!");
},
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions vir/src/converter/type_substitution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ impl Generic for Stmt {
}
Stmt::If(if_stmt) => Stmt::If(if_stmt.substitute(map)),
Stmt::Downcast(downcast) => Stmt::Downcast(downcast.substitute(map)),
Stmt::LeakCheck(leak_check) => Stmt::LeakCheck(leak_check.substitute(map)),
}
}
}
Expand Down Expand Up @@ -799,6 +800,12 @@ impl Generic for Downcast {
}
}

impl Generic for LeakCheck {
fn substitute(self, _map: &FxHashMap<TypeVar, Type>) -> Self {
self
}
}

// method
impl Generic for CfgMethod {
fn substitute(self, map: &FxHashMap<TypeVar, Type>) -> Self {
Expand Down

0 comments on commit 1b9c33f

Please sign in to comment.