From ca2d4e197811d03b3c6669868a1be75a8982d446 Mon Sep 17 00:00:00 2001 From: Alexander Batashev Date: Thu, 6 Jun 2024 00:00:03 +0300 Subject: [PATCH] feat: match-style op dispatcher (#63) --- core/tests/test_matcher.rs | 32 ++++++++++ macros/src/helpers.rs | 117 +++++++++++++++++++++++++++++++++++++ macros/src/lib.rs | 7 +++ 3 files changed, 156 insertions(+) create mode 100644 core/tests/test_matcher.rs create mode 100644 macros/src/helpers.rs diff --git a/core/tests/test_matcher.rs b/core/tests/test_matcher.rs new file mode 100644 index 0000000..39a5841 --- /dev/null +++ b/core/tests/test_matcher.rs @@ -0,0 +1,32 @@ +use tir_core::{ + builtin::{ConstOp, ModuleOp}, + Context, OpRef, +}; +use tir_macros::match_op; + +#[test] +fn match_ops() { + let context = Context::new(); + let module = ModuleOp::builder(&context).build(); + let module: OpRef = module; + let module2 = module.clone(); + let module3 = module.clone(); + let res = match_op!(module { + ModuleOp => |_| true, + _ => || false, + }); + assert_eq!(res, true); + + let res = match_op!(module2 { + ConstOp => |_| true, + _ => || false, + }); + assert_eq!(res, false); + + let res = match_op!(module3 { + ConstOp => |_| false, + ModuleOp => |_| true, + _ => || false, + }); + assert_eq!(res, true); +} diff --git a/macros/src/helpers.rs b/macros/src/helpers.rs new file mode 100644 index 0000000..e056425 --- /dev/null +++ b/macros/src/helpers.rs @@ -0,0 +1,117 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{braced, parse::Parse, parse_macro_input, punctuated::Punctuated, Token}; + +struct MatchArm { + op: Option, + body: syn::Expr, +} + +impl Parse for MatchArm { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + if let Ok(_) = input.parse::() { + input.parse::]>()?; + let body: syn::Expr = input.parse()?; + + return Ok(Self { op: None, body }); + } + + let op: syn::Ident = input.parse()?; + let op = Some(op); + input.parse::]>()?; + let body: syn::Expr = input.parse()?; + + Ok(Self { op, body }) + } +} + +struct MatchInput { + target: syn::Ident, + arms: Vec, + catch_all: Option, +} + +impl Parse for MatchInput { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let target: syn::Ident = input.parse()?; + let body; + braced!(body in input); + // panic!("{:?}", body); + let all_arms = Punctuated::::parse_terminated(&body)?; + let mut arms = vec![]; + let mut catch_all = None; + + for arm in all_arms { + if arm.op.is_none() { + catch_all = Some(arm); + } else { + arms.push(arm); + } + } + + Ok(Self { + target, + arms, + catch_all, + }) + } +} + +/// Expand match_op! { ... } helper macro +pub(crate) fn op_matcher(input: TokenStream) -> TokenStream { + let match_input = parse_macro_input!(input as MatchInput); + + let mut tokens = vec![]; + + let mut it = match_input.arms.into_iter(); + + let first = it.next().unwrap(); + let ty = first.op.unwrap(); + let body = first.body; + + let target = match_input.target; + + tokens.push(quote! { + if (#target.borrow().type_id() == std::any::TypeId::of::<#ty>()) { + let concrete = tir_core::utils::op_cast::<#ty>(#target).unwrap(); + let lambda = #body; + lambda(concrete) + } + }); + + for arm in it { + let ty = arm.op.unwrap(); + let body = arm.body; + tokens.push(quote! { + else if (#target.borrow().type_id() == std::any::TypeId::of::<#ty>()) { + let concrete = tir_core::utils::op_cast::<#ty>(#target).unwrap(); + let lambda = #body; + lambda(concrete) + } + }); + } + + if let Some(catch_all) = match_input.catch_all { + let body = catch_all.body; + + tokens.push(quote! { + else { + let lambda = #body; + lambda() + } + }); + } else { + tokens.push(quote! { + else { + unreachable!() + } + }); + } + + quote! { + #( + #tokens + )* + } + .into() +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 13f1bcb..b8e2994 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,9 +1,11 @@ extern crate proc_macro; mod assembly; +mod helpers; mod op_impl; pub(crate) use assembly::*; +pub(crate) use helpers::*; pub(crate) use op_impl::*; use case_converter::camel_to_snake; @@ -612,3 +614,8 @@ pub fn uppercase(input: TokenStream) -> TokenStream { } .into() } + +#[proc_macro] +pub fn match_op(input: TokenStream) -> TokenStream { + op_matcher(input) +}