Skip to content

Commit

Permalink
Support Unit tuples and tuples with multiple fields
Browse files Browse the repository at this point in the history
  • Loading branch information
parasyte committed Apr 10, 2023
1 parent 4ad46ce commit d8ef93f
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 57 deletions.
55 changes: 38 additions & 17 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
use crate::parser::{Error, ErrorSource, VariantType};
use myn::utils::spanned_error;
use proc_macro::{Span, TokenStream};
use std::str::FromStr as _;
use std::{rc::Rc, str::FromStr as _};

mod parser;

#[allow(clippy::too_many_lines)]
#[proc_macro_derive(Error, attributes(error, from, source))]
pub fn derive_error(input: TokenStream) -> TokenStream {
let ast = match Error::parse(input) {
Expand All @@ -26,12 +27,20 @@ pub fn derive_error(input: TokenStream) -> TokenStream {
ErrorSource::From(index) | ErrorSource::Source(index) => {
let name = &v.name;

if v.ty == VariantType::Tuple {
// TODO: Support more than one field for #[source]
Some(format!("Self::{name}(field) => Some(field),"))
} else {
Some(format!("Self::{name} {{ {index}, ..}} => Some({index}),"))
}
Some(match &v.ty {
VariantType::Unit => format!("Self::{name} => None,"),
VariantType::Tuple => {
let index_num: usize = index.parse().unwrap_or_default();
let fields = (0..v.fields.len())
.map(|i| if i == index_num { "field," } else { "_," })
.collect::<String>();

format!("Self::{name}({fields}) => Some(field),")
}
VariantType::Struct => {
format!("Self::{name} {{ {index}, ..}} => Some({index}),")
}
})
}
ErrorSource::None => None,
})
Expand All @@ -42,17 +51,29 @@ pub fn derive_error(input: TokenStream) -> TokenStream {
.map(|v| {
let name = &v.name;
let display = &v.display;
let fields = v
.display_fields
.iter()
.map(|field| format!("{field},"))
.collect::<String>();

if v.ty == VariantType::Tuple {
// TODO: Support more than one field for #[source]
format!(r#"Self::{name}(_) => write!(f, {display:?})?,"#)
} else {
format!(r#"Self::{name} {{ {fields} .. }} => write!(f, {display:?})?,"#)
match &v.ty {
VariantType::Unit => format!("Self::{name} => write!(f, {display:?})?,"),
VariantType::Tuple => {
let fields = (0..v.fields.len())
.map(|i| {
if v.display_fields.contains(&Rc::from(format!("field_{i}"))) {
format!("field_{i},")
} else {
"_,".to_string()
}
})
.collect::<String>();
format!("Self::{name}({fields}) => write!(f, {display:?})?,")
}
VariantType::Struct => {
let display_fields = v
.display_fields
.iter()
.map(|field| format!("{field},"))
.collect::<String>();
format!("Self::{name} {{ {display_fields} .. }} => write!(f, {display:?})?,")
}
}
})
.collect::<String>();
Expand Down
97 changes: 57 additions & 40 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub(crate) struct Variant {

#[derive(Debug, PartialEq)]
pub(crate) enum VariantType {
Unit,
Tuple,
Struct,
}
Expand Down Expand Up @@ -70,61 +71,77 @@ impl Variant {
pub(crate) fn parse(input: &mut TokenIter) -> Result<Self, TokenStream> {
let attrs = input.parse_attributes()?;
let name = input.as_ident()?;
// TODO: Group is optional for unit variants.
let group = input.as_group()?;
let _ = input.expect_punct(',');

let (ty, map) = match group.delimiter() {
Delimiter::Parenthesis => (VariantType::Tuple, parse_tuple_fields(group.stream())?),
Delimiter::Brace => (VariantType::Struct, parse_struct_fields(group.stream())?),
_ => return Err(spanned_error("Unexpected delimiter", group.span())),
};

// Resolve error source.
let mut fields = HashMap::new();
let mut source = ErrorSource::None;
let num_fields = map.len();
for (key, field) in map.into_iter() {
let attrs = field
.attrs
.iter()
.filter(|attr| ["from", "source"].contains(&attr.name.to_string().as_str()));

for attr in attrs {
// De-dupe.
if let Some(name) = source.as_ref() {
let msg = format!(
"#[from] | #[source] can only be used once. \
Previously seen on field `{name}`"
);

return Err(spanned_error(msg, attr.name.span()));
}

if attr.name.to_string() == "from" {
if num_fields > 1 {
return Err(spanned_error(
"#[from] can only be used with a single field",
name.span(),
));
let ty = if let Ok(group) = input.as_group() {
let (ty, map) = match group.delimiter() {
Delimiter::Parenthesis => (VariantType::Tuple, parse_tuple_fields(group.stream())?),
Delimiter::Brace => (VariantType::Struct, parse_struct_fields(group.stream())?),
_ => return Err(spanned_error("Unexpected delimiter", group.span())),
};

// Resolve error source.
let num_fields = map.len();
for (key, field) in map.into_iter() {
let attrs = field
.attrs
.iter()
.filter(|attr| ["from", "source"].contains(&attr.name.to_string().as_str()));

for attr in attrs {
// De-dupe.
if let Some(name) = source.as_ref() {
let msg = format!(
"#[from] | #[source] can only be used once. \
Previously seen on field `{name}`"
);

return Err(spanned_error(msg, attr.name.span()));
}

source = ErrorSource::From(key.clone());
} else {
source = ErrorSource::Source(key.clone());
if attr.name.to_string() == "from" {
if num_fields > 1 {
return Err(spanned_error(
"#[from] can only be used with a single field",
name.span(),
));
}

source = ErrorSource::From(key.clone());
} else {
source = ErrorSource::Source(key.clone());
}
}

fields.insert(key, field.path);
}

fields.insert(key, field.path);
}
let _ = input.expect_punct(',');

ty
} else {
VariantType::Unit
};

// #[error] attributes override doc comments
let display = if let Some(mut tree) = attrs
.iter()
.find_map(|attr| (attr.name.to_string() == "error").then_some(attr.tree.clone()))
.and_then(|mut tree| tree.expect_group(Delimiter::Parenthesis).ok())
{
tree.as_lit()?.as_string()?
let mut string = tree.as_lit()?.as_string()?;

if ty == VariantType::Tuple {
// Replace field references
for i in 0..20 {
string = string
.replace(&format!("{{{i}:"), &format!("{{field_{i}:"))
.replace(&format!("{{{i}}}"), &format!("{{field_{i}}}"));
}
}

string
} else {
get_doc_comment(&attrs).join("")
};
Expand Down

0 comments on commit d8ef93f

Please sign in to comment.