Skip to content

Commit

Permalink
feat: add type asm extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanYakimov committed May 20, 2024
1 parent 22d6711 commit 8c3e044
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 30 deletions.
2 changes: 1 addition & 1 deletion core/checks/ops/const.tir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

; test
module {
func @foo(%arg0: !void attrs = {}) -> !void attrs = {} {
func @foo(%arg0: !void) -> !void {
^entry:
const attrs = {value = <i8: 0>} -> !int attrs = {bits = <u32: 8>}
}
Expand Down
4 changes: 2 additions & 2 deletions core/checks/ops/function.tir
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

module {
; CHECK-LABEL: foo
func @foo(%arg0: !void attrs = {}) -> !void attrs = {} {
func @foo(%arg0: !void) -> !void {
; CHECK: ^entry:
^entry:
; CHECK-NEXT: const
; CHECK-SAME: value = <i8: 0>
const attrs = {value = <i8: 0>} -> !void attrs = {}
const attrs = {value = <i8: 0>} -> !void
}
}
7 changes: 7 additions & 0 deletions core/src/assembly/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ where
delimited(space0, inner, space0)
}

pub fn skip_attrs(
_input: &mut ParseStream<'_>,
) -> AsmPResult<std::collections::HashMap<String, Attr>> {
let res: std::collections::HashMap<String, Attr> = HashMap::new();
Ok(res)
}

fn single_comment(input: &mut ParseStream<'_>) -> AsmPResult<()> {
(';', take_till(1.., ['\n', '\r']), line_ending)
.void()
Expand Down
7 changes: 2 additions & 5 deletions core/src/builtin/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ mod test {

let mut printer = StringPrinter::new();
constant.borrow().print(&mut printer);
assert_eq!(
printer.get(),
"const attrs = {value = <i8: 16>} -> !void attrs = {}\n"
);
assert_eq!(printer.get(), "const attrs = {value = <i8: 16>} -> !void\n");

builder.insert(&constant);
assert_eq!(
Expand All @@ -83,7 +80,7 @@ mod test {
fn parse_const() {
let ir = "
module {
const attrs = {value = <i8: 16>} -> !void attrs = {}
const attrs = {value = <i8: 16>} -> !void
}
";

Expand Down
23 changes: 19 additions & 4 deletions core/src/builtin/types.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
use crate::Printable;
use crate::{Attr, ContextRef, Ty, TyAssembly, Type};
use std::collections::HashMap;
use tir_macros::dialect_type;
use tir_macros::{dialect_type, dialect_type_with_extensions};

use crate as tir_core;

use crate::builtin::DIALECT_NAME;

dialect_type!(FuncType);
dialect_type_with_extensions!(FuncType);
dialect_type!(VoidType);
dialect_type!(IntType);
dialect_type_with_extensions!(IntType);

impl TyAssembly for VoidType {
fn print_assembly(
_attrs: &HashMap<String, tir_core::Attr>,
fmt: &mut dyn tir_core::IRFormatter,
) {
fmt.write_direct("void");
}

fn parse_assembly(
input: &mut tir_core::parser::ParseStream<'_>,
) -> tir_core::parser::AsmPResult<std::collections::HashMap<String, tir_core::Attr>> {
tir_core::parser::skip_attrs(input)
}
}

impl FuncType {
fn get_inputs_attr_name() -> &'static str {
Expand Down Expand Up @@ -123,7 +138,7 @@ mod tests {
let ty = VoidType::build(context.clone());
let mut printer = StringPrinter::new();
ty.print(&mut printer);
assert_eq!("!void attrs = {}", &printer.get());
assert_eq!("!void", &printer.get());
let ty: Type = ty.into();
assert!(ty.isa::<VoidType>());
assert!(VoidType::try_from(ty.clone()).is_ok());
Expand Down
58 changes: 40 additions & 18 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,29 +45,12 @@ pub fn dialect(input: TokenStream) -> TokenStream {
})
}

#[proc_macro]
pub fn dialect_type(input: TokenStream) -> TokenStream {
let name_ident = parse_macro_input!(input as syn::Ident);
fn dialect_type_extension(name_ident: syn::Ident) -> TokenStream {
let name_string = name_ident.to_string();
let name_str = name_string.strip_suffix("Type").unwrap_or(&name_string);
let name_str = &camel_to_snake(name_str)[1..];

quote! {
#[derive(Clone)]
pub struct #name_ident {
r#type: Type,
}

impl tir_core::Ty for #name_ident {
fn get_type_name() -> &'static str {
#name_str
}

fn get_dialect_name() -> &'static str {
DIALECT_NAME
}
}

impl tir_core::TyAssembly for #name_ident {
fn print_assembly(attrs: &HashMap<String, tir_core::Attr>, fmt: &mut dyn tir_core::IRFormatter) {
// FIXME: make attrs optional
Expand All @@ -87,6 +70,30 @@ pub fn dialect_type(input: TokenStream) -> TokenStream {
tir_core::parser::attr_list(input)
}
}
}
.into()
}

fn dialect_type_base(name_ident: syn::Ident) -> TokenStream {
let name_string = name_ident.to_string();
let name_str = name_string.strip_suffix("Type").unwrap_or(&name_string);
let name_str = &camel_to_snake(name_str)[1..];

quote! {
#[derive(Clone)]
pub struct #name_ident {
r#type: Type,
}

impl tir_core::Ty for #name_ident {
fn get_type_name() -> &'static str {
#name_str
}

fn get_dialect_name() -> &'static str {
DIALECT_NAME
}
}

impl tir_core::Printable for #name_ident {
fn print(&self, fmt: &mut dyn crate::IRFormatter) {
Expand Down Expand Up @@ -147,6 +154,21 @@ pub fn dialect_type(input: TokenStream) -> TokenStream {
.into()
}

#[proc_macro]
pub fn dialect_type(input: TokenStream) -> TokenStream {
let name_ident = parse_macro_input!(input as syn::Ident);
dialect_type_base(name_ident)
}

#[proc_macro]
pub fn dialect_type_with_extensions(input: TokenStream) -> TokenStream {
let name_ident = parse_macro_input!(input as syn::Ident);
let base = dialect_type_base(name_ident.clone());
let extension = dialect_type_extension(name_ident);
let res = vec![base, extension];
TokenStream::from_iter(res)
}

#[proc_macro]
pub fn populate_dialect_ops(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as Types);
Expand Down

0 comments on commit 8c3e044

Please sign in to comment.