Skip to content

Commit

Permalink
feat: match-style op dispatcher (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbatashev committed Jun 5, 2024
1 parent 9df939f commit ca2d4e1
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 0 deletions.
32 changes: 32 additions & 0 deletions core/tests/test_matcher.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use tir_core::{
builtin::{ConstOp, ModuleOp},
Context, OpRef,
};
use tir_macros::match_op;

#[test]
fn match_ops() {
let context = Context::new();
let module = ModuleOp::builder(&context).build();
let module: OpRef = module;
let module2 = module.clone();
let module3 = module.clone();
let res = match_op!(module {
ModuleOp => |_| true,
_ => || false,
});
assert_eq!(res, true);

let res = match_op!(module2 {
ConstOp => |_| true,
_ => || false,
});
assert_eq!(res, false);

let res = match_op!(module3 {
ConstOp => |_| false,
ModuleOp => |_| true,
_ => || false,
});
assert_eq!(res, true);
}
117 changes: 117 additions & 0 deletions macros/src/helpers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{braced, parse::Parse, parse_macro_input, punctuated::Punctuated, Token};

struct MatchArm {
op: Option<syn::Ident>,
body: syn::Expr,
}

impl Parse for MatchArm {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
if let Ok(_) = input.parse::<Token![_]>() {
input.parse::<Token![=>]>()?;
let body: syn::Expr = input.parse()?;

return Ok(Self { op: None, body });
}

let op: syn::Ident = input.parse()?;
let op = Some(op);
input.parse::<Token![=>]>()?;
let body: syn::Expr = input.parse()?;

Ok(Self { op, body })
}
}

struct MatchInput {
target: syn::Ident,
arms: Vec<MatchArm>,
catch_all: Option<MatchArm>,
}

impl Parse for MatchInput {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let target: syn::Ident = input.parse()?;
let body;
braced!(body in input);
// panic!("{:?}", body);
let all_arms = Punctuated::<MatchArm, Token![,]>::parse_terminated(&body)?;
let mut arms = vec![];
let mut catch_all = None;

for arm in all_arms {
if arm.op.is_none() {
catch_all = Some(arm);
} else {
arms.push(arm);
}
}

Ok(Self {
target,
arms,
catch_all,
})
}
}

/// Expand match_op! { ... } helper macro
pub(crate) fn op_matcher(input: TokenStream) -> TokenStream {
let match_input = parse_macro_input!(input as MatchInput);

let mut tokens = vec![];

let mut it = match_input.arms.into_iter();

let first = it.next().unwrap();
let ty = first.op.unwrap();
let body = first.body;

let target = match_input.target;

tokens.push(quote! {
if (#target.borrow().type_id() == std::any::TypeId::of::<#ty>()) {
let concrete = tir_core::utils::op_cast::<#ty>(#target).unwrap();
let lambda = #body;
lambda(concrete)
}
});

for arm in it {
let ty = arm.op.unwrap();
let body = arm.body;
tokens.push(quote! {
else if (#target.borrow().type_id() == std::any::TypeId::of::<#ty>()) {
let concrete = tir_core::utils::op_cast::<#ty>(#target).unwrap();
let lambda = #body;
lambda(concrete)
}
});
}

if let Some(catch_all) = match_input.catch_all {
let body = catch_all.body;

tokens.push(quote! {
else {
let lambda = #body;
lambda()
}
});
} else {
tokens.push(quote! {
else {
unreachable!()
}
});
}

quote! {
#(
#tokens
)*
}
.into()
}
7 changes: 7 additions & 0 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
extern crate proc_macro;

mod assembly;
mod helpers;
mod op_impl;

pub(crate) use assembly::*;
pub(crate) use helpers::*;
pub(crate) use op_impl::*;

use case_converter::camel_to_snake;
Expand Down Expand Up @@ -612,3 +614,8 @@ pub fn uppercase(input: TokenStream) -> TokenStream {
}
.into()
}

#[proc_macro]
pub fn match_op(input: TokenStream) -> TokenStream {
op_matcher(input)
}

0 comments on commit ca2d4e1

Please sign in to comment.