Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions src/internal/id.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::fmt::{Display, Formatter};
use std::{
fmt::{Display, Formatter},
num::NonZeroU32,
};

use crate::{internal::arena::ArenaId, Interner};

Expand Down Expand Up @@ -165,32 +168,24 @@ impl From<SolvableId> for u32 {

#[repr(transparent)]
#[derive(Copy, Clone, PartialOrd, Ord, Eq, PartialEq, Debug, Hash)]
pub(crate) struct ClauseId(u32);
pub(crate) struct ClauseId(NonZeroU32);

impl ClauseId {
/// There is a guarentee that ClauseId(0) will always be
/// There is a guarentee that ClauseId(1) will always be
/// "Clause::InstallRoot". This assumption is verified by the solver.
pub(crate) fn install_root() -> Self {
Self(0)
}

pub(crate) fn is_null(self) -> bool {
self.0 == u32::MAX
}

pub(crate) fn null() -> ClauseId {
ClauseId(u32::MAX)
Self(unsafe { NonZeroU32::new_unchecked(1) })
}
}

impl ArenaId for ClauseId {
fn from_usize(x: usize) -> Self {
assert!(x < u32::MAX as usize, "clause id too big");
Self(x as u32)
// SAFETY: Safe because we always add 1 to the index
Self(unsafe { NonZeroU32::new_unchecked((x + 1).try_into().expect("clause id too big")) })
}

fn to_usize(self) -> usize {
self.0 as usize
(self.0.get() - 1) as usize
}
}

Expand Down Expand Up @@ -236,3 +231,17 @@ impl ArenaId for DependenciesId {
self.0 as usize
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_clause_id_size() {
// Verify that the size of a ClauseId is the same as an Option<ClauseId>.
assert_eq!(
std::mem::size_of::<ClauseId>(),
std::mem::size_of::<Option<ClauseId>>()
);
}
}
15 changes: 15 additions & 0 deletions src/internal/mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,21 @@ impl<TId: ArenaId, TValue> Mapping<TId, TValue> {
previous_value
}

/// Unset a specific value in the mapping, returns the previous value.
pub fn unset(&mut self, id: TId) -> Option<TValue> {
let idx = id.to_usize();
let (chunk, offset) = Self::chunk_and_offset(idx);
if chunk >= self.chunks.len() {
return None;
}

let previous_value = self.chunks[chunk][offset].take();
if previous_value.is_some() {
self.len -= 1;
}
previous_value
}

/// Get a specific value in the mapping with bound checks
pub fn get(&self, id: TId) -> Option<&TValue> {
let (chunk, offset) = Self::chunk_and_offset(id.to_usize());
Expand Down
62 changes: 31 additions & 31 deletions src/solver/clause.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ pub(crate) struct ClauseState {
// The ids of the solvables this clause is watching
pub watched_literals: [Literal; 2],
// The ids of the next clause in each linked list that this clause is part of
pub(crate) next_watches: [ClauseId; 2],
pub(crate) next_watches: [Option<ClauseId>; 2],
}

impl ClauseState {
Expand Down Expand Up @@ -417,15 +417,15 @@ impl ClauseState {

let clause = Self {
watched_literals,
next_watches: [ClauseId::null(), ClauseId::null()],
next_watches: [None, None],
};

debug_assert!(!clause.has_watches() || watched_literals[0] != watched_literals[1]);

clause
}

pub fn link_to_clause(&mut self, watch_index: usize, linked_clause: ClauseId) {
pub fn link_to_clause(&mut self, watch_index: usize, linked_clause: Option<ClauseId>) {
self.next_watches[watch_index] = linked_clause;
}

Expand All @@ -444,7 +444,7 @@ impl ClauseState {
}

#[inline]
pub fn next_watched_clause(&self, solvable_id: InternalSolvableId) -> ClauseId {
pub fn next_watched_clause(&self, solvable_id: InternalSolvableId) -> Option<ClauseId> {
if solvable_id == self.watched_literals[0].solvable_id() {
self.next_watches[0]
} else {
Expand Down Expand Up @@ -650,7 +650,7 @@ mod test {
use super::*;
use crate::{internal::arena::ArenaId, solver::decision::Decision};

fn clause(next_clauses: [ClauseId; 2], watch_literals: [Literal; 2]) -> ClauseState {
fn clause(next_clauses: [Option<ClauseId>; 2], watch_literals: [Literal; 2]) -> ClauseState {
ClauseState {
watched_literals: watch_literals,
next_watches: next_clauses,
Expand Down Expand Up @@ -691,21 +691,24 @@ mod test {
#[test]
fn test_unlink_clause_different() {
let clause1 = clause(
[ClauseId::from_usize(2), ClauseId::from_usize(3)],
[
ClauseId::from_usize(2).into(),
ClauseId::from_usize(3).into(),
],
[
InternalSolvableId::from_usize(1596).negative(),
InternalSolvableId::from_usize(1211).negative(),
],
);
let clause2 = clause(
[ClauseId::null(), ClauseId::from_usize(3)],
[None, ClauseId::from_usize(3).into()],
[
InternalSolvableId::from_usize(1596).negative(),
InternalSolvableId::from_usize(1208).negative(),
],
);
let clause3 = clause(
[ClauseId::null(), ClauseId::null()],
[None, None],
[
InternalSolvableId::from_usize(1211).negative(),
InternalSolvableId::from_usize(42).negative(),
Expand All @@ -723,10 +726,7 @@ mod test {
InternalSolvableId::from_usize(1211).negative()
]
);
assert_eq!(
clause1.next_watches,
[ClauseId::null(), ClauseId::from_usize(3)]
)
assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(3).into()])
}

// Unlink 1
Expand All @@ -740,24 +740,24 @@ mod test {
InternalSolvableId::from_usize(1211).negative()
]
);
assert_eq!(
clause1.next_watches,
[ClauseId::from_usize(2), ClauseId::null()]
)
assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None])
}
}

#[test]
fn test_unlink_clause_same() {
let clause1 = clause(
[ClauseId::from_usize(2), ClauseId::from_usize(2)],
[
ClauseId::from_usize(2).into(),
ClauseId::from_usize(2).into(),
],
[
InternalSolvableId::from_usize(1596).negative(),
InternalSolvableId::from_usize(1211).negative(),
],
);
let clause2 = clause(
[ClauseId::null(), ClauseId::null()],
[None, None],
[
InternalSolvableId::from_usize(1596).negative(),
InternalSolvableId::from_usize(1211).negative(),
Expand All @@ -775,10 +775,7 @@ mod test {
InternalSolvableId::from_usize(1211).negative()
]
);
assert_eq!(
clause1.next_watches,
[ClauseId::null(), ClauseId::from_usize(2)]
)
assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(2).into()])
}

// Unlink 1
Expand All @@ -792,10 +789,7 @@ mod test {
InternalSolvableId::from_usize(1211).negative()
]
);
assert_eq!(
clause1.next_watches,
[ClauseId::from_usize(2), ClauseId::null()]
)
assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None])
}
}

Expand All @@ -820,7 +814,10 @@ mod test {

// No conflict, still one candidate available
decisions
.try_add_decision(Decision::new(candidate1.into(), false, ClauseId::null()), 1)
.try_add_decision(
Decision::new(candidate1.into(), false, ClauseId::from_usize(0)),
1,
)
.unwrap();
let (clause, conflict, _kind) = ClauseState::requires(
parent,
Expand All @@ -834,7 +831,10 @@ mod test {

// Conflict, no candidates available
decisions
.try_add_decision(Decision::new(candidate2.into(), false, ClauseId::null()), 1)
.try_add_decision(
Decision::new(candidate2.into(), false, ClauseId::install_root()),
1,
)
.unwrap();
let (clause, conflict, _kind) = ClauseState::requires(
parent,
Expand All @@ -848,7 +848,7 @@ mod test {

// Panic
decisions
.try_add_decision(Decision::new(parent, false, ClauseId::null()), 1)
.try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1)
.unwrap();
let panicked = std::panic::catch_unwind(|| {
ClauseState::requires(
Expand Down Expand Up @@ -878,7 +878,7 @@ mod test {

// Conflict, forbidden package installed
decisions
.try_add_decision(Decision::new(forbidden, true, ClauseId::null()), 1)
.try_add_decision(Decision::new(forbidden, true, ClauseId::install_root()), 1)
.unwrap();
let (clause, conflict, _kind) =
ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions);
Expand All @@ -888,7 +888,7 @@ mod test {

// Panic
decisions
.try_add_decision(Decision::new(parent, false, ClauseId::null()), 1)
.try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1)
.unwrap();
let panicked = std::panic::catch_unwind(|| {
ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions)
Expand Down
20 changes: 8 additions & 12 deletions src/solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1435,11 +1435,8 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
// solvable
let mut old_predecessor_clause_id: Option<ClauseId>;
let mut predecessor_clause_id: Option<ClauseId> = None;
let mut clause_id = self
.watches
.first_clause_watching_literal(watched_literal)
.unwrap_or(ClauseId::null());
while !clause_id.is_null() {
let mut next_clause_id = self.watches.first_clause_watching_literal(watched_literal);
while let Some(clause_id) = next_clause_id {
debug_assert!(
predecessor_clause_id != Some(clause_id),
"Linked list is circular!"
Expand All @@ -1466,8 +1463,7 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
predecessor_clause_id = Some(clause_id);

// Configure the next clause to visit
let this_clause_id = clause_id;
clause_id = clause_state.next_watched_clause(watched_literal.solvable_id());
next_clause_id = clause_state.next_watched_clause(watched_literal.solvable_id());

// Determine which watch turned false.
let (watch_index, other_watch_index) = if clause_state.watched_literals[0]
Expand All @@ -1492,7 +1488,7 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
// If the other watch is already true, we can simply skip
// this clause.
} else if let Some(variable) = clause_state.next_unwatched_literal(
&clauses[this_clause_id.to_usize()],
&clauses[clause_id.to_usize()],
&self.learnt_clauses,
&self.cache.requirement_to_sorted_candidates,
self.decision_tracker.map(),
Expand All @@ -1501,7 +1497,7 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
self.watches.update_watched(
predecessor_clause_state,
clause_state,
this_clause_id,
clause_id,
watch_index,
watched_literal,
variable,
Expand All @@ -1527,20 +1523,20 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
Decision::new(
remaining_watch.solvable_id(),
remaining_watch.satisfying_value(),
this_clause_id,
clause_id,
),
level,
)
.map_err(|_| {
PropagationError::Conflict(
remaining_watch.solvable_id(),
true,
this_clause_id,
clause_id,
)
})?;

if decided {
let clause = &clauses[this_clause_id.to_usize()];
let clause = &clauses[clause_id.to_usize()];
match clause {
// Skip logging for ForbidMultipleInstances, which is so noisy
Clause::ForbidMultipleInstances(..) => {}
Expand Down
19 changes: 7 additions & 12 deletions src/solver/watch_map.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::solver::clause::Literal;
use crate::{
internal::{id::ClauseId, mapping::Mapping},
solver::clause::ClauseState,
solver::clause::{ClauseState, Literal},
};

/// A map from solvables to the clauses that are watching them
Expand All @@ -20,9 +19,7 @@ impl WatchMap {

pub(crate) fn start_watching(&mut self, clause: &mut ClauseState, clause_id: ClauseId) {
for (watch_index, watched_literal) in clause.watched_literals.into_iter().enumerate() {
let already_watching = self
.first_clause_watching_literal(watched_literal)
.unwrap_or(ClauseId::null());
let already_watching = self.first_clause_watching_literal(watched_literal);
clause.link_to_clause(watch_index, already_watching);
self.watch_literal(watched_literal, clause_id);
}
Expand All @@ -42,18 +39,16 @@ impl WatchMap {
if let Some(predecessor_clause) = predecessor_clause {
// Unlink the clause
predecessor_clause.unlink_clause(clause, previous_watch.solvable_id(), watch_index);
} else {
} else if let Some(next_watch) = clause.next_watches[watch_index] {
// This was the first clause in the chain
self.map
.insert(previous_watch, clause.next_watches[watch_index]);
self.map.insert(previous_watch, next_watch);
} else {
self.map.unset(previous_watch);
}

// Set the new watch
clause.watched_literals[watch_index] = new_watch;
let previous_clause_id = self
.map
.insert(new_watch, clause_id)
.unwrap_or(ClauseId::null());
let previous_clause_id = self.map.insert(new_watch, clause_id);
clause.next_watches[watch_index] = previous_clause_id;
}

Expand Down