Skip to content

Commit

Permalink
Merge pull request #284 from nikomatsakis/generic-rust-ir
Browse files Browse the repository at this point in the history
make chalk-rust-ir generic over type-family
  • Loading branch information
nikomatsakis committed Nov 16, 2019
2 parents 102eba3 + 1b38304 commit a88cad7
Show file tree
Hide file tree
Showing 38 changed files with 873 additions and 841 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

104 changes: 104 additions & 0 deletions chalk-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,110 @@ fn derive_fold_body(type_name: &Ident, data: Data) -> proc_macro2::TokenStream {
}
}

#[proc_macro_derive(HasTypeFamily, attributes(has_type_family))]
pub fn derive_has_type_family(item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as DeriveInput);
let (impl_generics, ty_generics, where_clause_ref) = input.generics.split_for_impl();

let type_name = input.ident;

if let Some(attr) = input
.attrs
.iter()
.find(|a| a.path.is_ident("has_type_family"))
{
// Hardcoded type-family:
//
// impl HasTypeFamily for Type {
// type Result = XXX;
// }
let arg = attr
.parse_args::<proc_macro2::TokenStream>()
.expect("Expected has_type_family argument");

return TokenStream::from(quote! {
impl #impl_generics HasTypeFamily for #type_name #ty_generics #where_clause_ref {
type TypeFamily = #arg;
}
});
}

match input.generics.params.len() {
1 => {}

0 => {
panic!(
"TypeFamily derive requires a single type parameter or a `#[has_type_family]` attr"
);
}

_ => {
panic!("TypeFamily derive only works with a single type parameter");
}
};

let generic_param0 = &input.generics.params[0];

if let Some(param) = has_type_family(&generic_param0) {
// HasTypeFamily bound:
//
// Example:
//
// impl<T, _TF> HasTypeFamily for Binders<T>
// where
// T: HasTypeFamily<TypeFamily = _TF>,
// _TF: TypeFamily,
// {
// type Result = _TF;
// }

let mut impl_generics = input.generics.clone();
impl_generics.params.extend(vec![GenericParam::Type(
syn::parse(quote! { _TF: TypeFamily }.into()).unwrap(),
)]);

let mut where_clause = where_clause_ref
.cloned()
.unwrap_or_else(|| syn::parse2(quote![where]).unwrap());
where_clause
.predicates
.push(syn::parse2(quote! { #param: HasTypeFamily<TypeFamily = _TF> }).unwrap());

return TokenStream::from(quote! {
impl #impl_generics HasTypeFamily for #type_name < #param >
#where_clause
{
type TypeFamily = _TF;
}
});
}

// TypeFamily bound:
//
// Example:
//
// impl<TF> HasTypeFamily for Foo<TF>
// where
// TF: TypeFamily,
// {
// type TypeFamily = TF;
// }

if let Some(tf) = is_type_family(&generic_param0) {
let impl_generics = &input.generics;

return TokenStream::from(quote! {
impl #impl_generics HasTypeFamily for #type_name < #tf >
#where_clause_ref
{
type TypeFamily = #tf;
}
});
}

panic!("derive(TypeFamily) requires a parameter that implements HasTypeFamily or TypeFamily");
}

/// Checks whether a generic parameter has a `: HasTypeFamily` bound
fn has_type_family(param: &GenericParam) -> Option<&Ident> {
bounded_by_trait(param, "HasTypeFamily")
Expand Down
17 changes: 10 additions & 7 deletions chalk-integration/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ impl ChalkDatabase {
Ok(chalk_parse::parse_goal(text)?.lower(&*program)?)
}

pub fn solve(&self, goal: &UCanonical<InEnvironment<Goal<ChalkIr>>>) -> Option<Solution> {
pub fn solve(
&self,
goal: &UCanonical<InEnvironment<Goal<ChalkIr>>>,
) -> Option<Solution<ChalkIr>> {
let solver = self.solver();
let solution = solver.lock().unwrap().solve(self, goal);
solution
Expand All @@ -76,28 +79,28 @@ impl ChalkDatabase {
}
}

impl RustIrDatabase for ChalkDatabase {
impl RustIrDatabase<ChalkIr> for ChalkDatabase {
fn custom_clauses(&self) -> Vec<ProgramClause<ChalkIr>> {
self.program_ir().unwrap().custom_clauses()
}

fn associated_ty_data(&self, ty: TypeId) -> Arc<AssociatedTyDatum> {
fn associated_ty_data(&self, ty: TypeId) -> Arc<AssociatedTyDatum<ChalkIr>> {
self.program_ir().unwrap().associated_ty_data(ty)
}

fn trait_datum(&self, id: TraitId) -> Arc<TraitDatum> {
fn trait_datum(&self, id: TraitId) -> Arc<TraitDatum<ChalkIr>> {
self.program_ir().unwrap().trait_datum(id)
}

fn impl_datum(&self, id: ImplId) -> Arc<ImplDatum> {
fn impl_datum(&self, id: ImplId) -> Arc<ImplDatum<ChalkIr>> {
self.program_ir().unwrap().impl_datum(id)
}

fn associated_ty_value(&self, id: AssociatedTyValueId) -> Arc<AssociatedTyValue> {
fn associated_ty_value(&self, id: AssociatedTyValueId) -> Arc<AssociatedTyValue<ChalkIr>> {
self.program_ir().unwrap().associated_ty_values[&id].clone()
}

fn struct_datum(&self, id: StructId) -> Arc<StructDatum> {
fn struct_datum(&self, id: StructId) -> Arc<StructDatum<ChalkIr>> {
self.program_ir().unwrap().struct_datum(id)
}

Expand Down
32 changes: 16 additions & 16 deletions chalk-integration/src/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,15 +625,15 @@ trait LowerStructDefn {
&self,
struct_id: chalk_ir::StructId,
env: &Env,
) -> LowerResult<rust_ir::StructDatum>;
) -> LowerResult<rust_ir::StructDatum<ChalkIr>>;
}

impl LowerStructDefn for StructDefn {
fn lower_struct(
&self,
struct_id: chalk_ir::StructId,
env: &Env,
) -> LowerResult<rust_ir::StructDatum> {
) -> LowerResult<rust_ir::StructDatum<ChalkIr>> {
if self.flags.fundamental && self.all_parameters().len() != 1 {
Err(RustIrError::InvalidFundamentalTypesParameters(self.name))?;
}
Expand Down Expand Up @@ -679,11 +679,11 @@ impl LowerTraitRef for TraitRef {
}

trait LowerTraitBound {
fn lower(&self, env: &Env) -> LowerResult<rust_ir::TraitBound>;
fn lower(&self, env: &Env) -> LowerResult<rust_ir::TraitBound<ChalkIr>>;
}

impl LowerTraitBound for TraitBound {
fn lower(&self, env: &Env) -> LowerResult<rust_ir::TraitBound> {
fn lower(&self, env: &Env) -> LowerResult<rust_ir::TraitBound<ChalkIr>> {
let trait_id = match env.lookup(self.trait_name)? {
NameLookup::Type(TypeKindId::TraitId(trait_id)) => trait_id,
NameLookup::Type(_) | NameLookup::Parameter(_) => {
Expand Down Expand Up @@ -728,11 +728,11 @@ impl LowerTraitBound for TraitBound {
}

trait LowerProjectionEqBound {
fn lower(&self, env: &Env) -> LowerResult<rust_ir::ProjectionEqBound>;
fn lower(&self, env: &Env) -> LowerResult<rust_ir::ProjectionEqBound<ChalkIr>>;
}

impl LowerProjectionEqBound for ProjectionEqBound {
fn lower(&self, env: &Env) -> LowerResult<rust_ir::ProjectionEqBound> {
fn lower(&self, env: &Env) -> LowerResult<rust_ir::ProjectionEqBound<ChalkIr>> {
let trait_bound = self.trait_bound.lower(env)?;
let lookup = match env
.associated_ty_lookups
Expand Down Expand Up @@ -775,11 +775,11 @@ impl LowerProjectionEqBound for ProjectionEqBound {
}

trait LowerInlineBound {
fn lower(&self, env: &Env) -> LowerResult<rust_ir::InlineBound>;
fn lower(&self, env: &Env) -> LowerResult<rust_ir::InlineBound<ChalkIr>>;
}

impl LowerInlineBound for InlineBound {
fn lower(&self, env: &Env) -> LowerResult<rust_ir::InlineBound> {
fn lower(&self, env: &Env) -> LowerResult<rust_ir::InlineBound<ChalkIr>> {
let bound = match self {
InlineBound::TraitBound(b) => rust_ir::InlineBound::TraitBound(b.lower(&env)?),
InlineBound::ProjectionEqBound(b) => {
Expand All @@ -791,23 +791,23 @@ impl LowerInlineBound for InlineBound {
}

trait LowerQuantifiedInlineBound {
fn lower(&self, env: &Env) -> LowerResult<rust_ir::QuantifiedInlineBound>;
fn lower(&self, env: &Env) -> LowerResult<rust_ir::QuantifiedInlineBound<ChalkIr>>;
}

impl LowerQuantifiedInlineBound for QuantifiedInlineBound {
fn lower(&self, env: &Env) -> LowerResult<rust_ir::QuantifiedInlineBound> {
fn lower(&self, env: &Env) -> LowerResult<rust_ir::QuantifiedInlineBound<ChalkIr>> {
let parameter_kinds = self.parameter_kinds.iter().map(|pk| pk.lower());
let binders = env.in_binders(parameter_kinds, |env| Ok(self.bound.lower(env)?))?;
Ok(binders)
}
}

trait LowerQuantifiedInlineBoundVec {
fn lower(&self, env: &Env) -> LowerResult<Vec<rust_ir::QuantifiedInlineBound>>;
fn lower(&self, env: &Env) -> LowerResult<Vec<rust_ir::QuantifiedInlineBound<ChalkIr>>>;
}

impl LowerQuantifiedInlineBoundVec for [QuantifiedInlineBound] {
fn lower(&self, env: &Env) -> LowerResult<Vec<rust_ir::QuantifiedInlineBound>> {
fn lower(&self, env: &Env) -> LowerResult<Vec<rust_ir::QuantifiedInlineBound<ChalkIr>>> {
self.iter().map(|b| b.lower(env)).collect()
}
}
Expand Down Expand Up @@ -1058,7 +1058,7 @@ trait LowerImpl {
empty_env: &Env,
impl_id: ImplId,
associated_ty_value_ids: &AssociatedTyValueIds,
) -> LowerResult<rust_ir::ImplDatum>;
) -> LowerResult<rust_ir::ImplDatum<ChalkIr>>;
}

impl LowerImpl for Impl {
Expand All @@ -1067,7 +1067,7 @@ impl LowerImpl for Impl {
empty_env: &Env,
impl_id: ImplId,
associated_ty_value_ids: &AssociatedTyValueIds,
) -> LowerResult<rust_ir::ImplDatum> {
) -> LowerResult<rust_ir::ImplDatum<ChalkIr>> {
debug_heading!("LowerImpl::lower_impl(impl_id={:?})", impl_id);

let polarity = self.polarity.lower();
Expand Down Expand Up @@ -1160,15 +1160,15 @@ trait LowerTrait {
&self,
trait_id: chalk_ir::TraitId,
env: &Env,
) -> LowerResult<rust_ir::TraitDatum>;
) -> LowerResult<rust_ir::TraitDatum<ChalkIr>>;
}

impl LowerTrait for TraitDefn {
fn lower_trait(
&self,
trait_id: chalk_ir::TraitId,
env: &Env,
) -> LowerResult<rust_ir::TraitDatum> {
) -> LowerResult<rust_ir::TraitDatum<ChalkIr>> {
let all_parameters = self.all_parameters();
let all_parameters_len = all_parameters.len();
let binders = env.in_binders(all_parameters, |env| {
Expand Down
22 changes: 11 additions & 11 deletions chalk-integration/src/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ pub struct Program {
pub type_kinds: BTreeMap<TypeKindId, TypeKind>,

/// For each struct:
pub struct_data: BTreeMap<StructId, Arc<StructDatum>>,
pub struct_data: BTreeMap<StructId, Arc<StructDatum<ChalkIr>>>,

/// For each impl:
pub impl_data: BTreeMap<ImplId, Arc<ImplDatum>>,
pub impl_data: BTreeMap<ImplId, Arc<ImplDatum<ChalkIr>>>,

/// For each associated ty value `type Foo = XXX` found in an impl:
pub associated_ty_values: BTreeMap<AssociatedTyValueId, Arc<AssociatedTyValue>>,
pub associated_ty_values: BTreeMap<AssociatedTyValueId, Arc<AssociatedTyValue<ChalkIr>>>,

/// For each trait:
pub trait_data: BTreeMap<TraitId, Arc<TraitDatum>>,
pub trait_data: BTreeMap<TraitId, Arc<TraitDatum<ChalkIr>>>,

/// For each associated ty declaration `type Foo` found in a trait:
pub associated_ty_data: BTreeMap<TypeId, Arc<AssociatedTyDatum>>,
pub associated_ty_data: BTreeMap<TypeId, Arc<AssociatedTyDatum<ChalkIr>>>,

/// For each user-specified clause
pub custom_clauses: Vec<ProgramClause<ChalkIr>>,
Expand Down Expand Up @@ -97,28 +97,28 @@ impl tls::DebugContext for Program {
}
}

impl RustIrDatabase for Program {
impl RustIrDatabase<ChalkIr> for Program {
fn custom_clauses(&self) -> Vec<ProgramClause<ChalkIr>> {
self.custom_clauses.clone()
}

fn associated_ty_data(&self, ty: TypeId) -> Arc<AssociatedTyDatum> {
fn associated_ty_data(&self, ty: TypeId) -> Arc<AssociatedTyDatum<ChalkIr>> {
self.associated_ty_data[&ty].clone()
}

fn trait_datum(&self, id: TraitId) -> Arc<TraitDatum> {
fn trait_datum(&self, id: TraitId) -> Arc<TraitDatum<ChalkIr>> {
self.trait_data[&id].clone()
}

fn impl_datum(&self, id: ImplId) -> Arc<ImplDatum> {
fn impl_datum(&self, id: ImplId) -> Arc<ImplDatum<ChalkIr>> {
self.impl_data[&id].clone()
}

fn associated_ty_value(&self, id: AssociatedTyValueId) -> Arc<AssociatedTyValue> {
fn associated_ty_value(&self, id: AssociatedTyValueId) -> Arc<AssociatedTyValue<ChalkIr>> {
self.associated_ty_values[&id].clone()
}

fn struct_datum(&self, id: StructId) -> Arc<StructDatum> {
fn struct_datum(&self, id: StructId) -> Arc<StructDatum<ChalkIr>> {
self.struct_data[&id].clone()
}

Expand Down
7 changes: 4 additions & 3 deletions chalk-integration/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::error::ChalkError;
use crate::lowering::LowerProgram;
use crate::program::Program;
use crate::program_environment::ProgramEnvironment;
use chalk_ir::family::ChalkIr;
use chalk_ir::tls;
use chalk_ir::TraitId;
use chalk_solve::clauses::builder::ClauseBuilder;
Expand All @@ -20,7 +21,7 @@ use std::sync::Arc;
use std::sync::Mutex;

#[salsa::query_group(Lowering)]
pub trait LoweringDatabase: RustIrDatabase {
pub trait LoweringDatabase: RustIrDatabase<ChalkIr> {
#[salsa::input]
fn program_text(&self) -> Arc<String>;

Expand Down Expand Up @@ -48,7 +49,7 @@ pub trait LoweringDatabase: RustIrDatabase {
/// volatile, thus ensuring that the solver is recreated in every
/// revision (i.e., each time source program changes).
#[salsa::volatile]
fn solver(&self) -> Arc<Mutex<Solver>>;
fn solver(&self) -> Arc<Mutex<Solver<ChalkIr>>>;
}

fn program_ir(db: &impl LoweringDatabase) -> Result<Arc<Program>, ChalkError> {
Expand Down Expand Up @@ -163,7 +164,7 @@ fn environment(db: &impl LoweringDatabase) -> Result<Arc<ProgramEnvironment>, Ch
Ok(Arc::new(ProgramEnvironment::new(program_clauses)))
}

fn solver(db: &impl LoweringDatabase) -> Arc<Mutex<Solver>> {
fn solver(db: &impl LoweringDatabase) -> Arc<Mutex<Solver<ChalkIr>>> {
let choice = db.solver_choice();
Arc::new(Mutex::new(choice.into_solver()))
}
Loading

0 comments on commit a88cad7

Please sign in to comment.