forked from RustPython/RustPython
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytraverse.rs
138 lines (129 loc) · 4.35 KB
/
pytraverse.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Attribute, DeriveInput, Field, Meta, MetaList, NestedMeta, Result};
struct TraverseAttr {
/// set to `true` if the attribute is `#[pytraverse(skip)]`
skip: bool,
}
const ATTR_TRAVERSE: &str = "pytraverse";
/// get the `#[pytraverse(..)]` attribute from the struct
fn valid_get_traverse_attr_from_meta_list(list: &MetaList) -> Result<TraverseAttr> {
let find_skip_and_only_skip = || {
let len = list.nested.len();
if len != 1 {
return None;
}
let mut iter = list.nested.iter();
// we have checked the length, so unwrap is safe
let first_arg = iter.next().unwrap();
let skip = match first_arg {
NestedMeta::Meta(Meta::Path(path)) => match path.is_ident("skip") {
true => true,
false => return None,
},
_ => return None,
};
Some(skip)
};
let skip = find_skip_and_only_skip().ok_or_else(|| {
err_span!(
list,
"only support attr is #[pytraverse(skip)], got arguments: {:?}",
list.nested
)
})?;
Ok(TraverseAttr { skip })
}
/// only accept `#[pytraverse(skip)]` for now
fn pytraverse_arg(attr: &Attribute) -> Option<Result<TraverseAttr>> {
if !attr.path.is_ident(ATTR_TRAVERSE) {
return None;
}
let ret = || {
let parsed = attr.parse_meta()?;
if let Meta::List(list) = parsed {
valid_get_traverse_attr_from_meta_list(&list)
} else {
bail_span!(attr, "pytraverse must be a list, like #[pytraverse(skip)]")
}
};
Some(ret())
}
fn field_to_traverse_code(field: &Field) -> Result<TokenStream> {
let pytraverse_attrs = field
.attrs
.iter()
.filter_map(pytraverse_arg)
.collect::<std::result::Result<Vec<_>, _>>()?;
let do_trace = if pytraverse_attrs.len() > 1 {
bail_span!(
field,
"found multiple #[pytraverse] attributes on the same field, expect at most one"
)
} else if pytraverse_attrs.is_empty() {
// default to always traverse every field
true
} else {
!pytraverse_attrs[0].skip
};
let name = field.ident.as_ref().ok_or_else(|| {
syn::Error::new_spanned(
field.clone(),
"Field should have a name in non-tuple struct",
)
})?;
if do_trace {
Ok(quote!(
::rustpython_vm::object::Traverse::traverse(&self.#name, tracer_fn);
))
} else {
Ok(quote!())
}
}
/// not trace corresponding field
fn gen_trace_code(item: &mut DeriveInput) -> Result<TokenStream> {
match &mut item.data {
syn::Data::Struct(s) => {
let fields = &mut s.fields;
match fields {
syn::Fields::Named(ref mut fields) => {
let res: Vec<TokenStream> = fields
.named
.iter_mut()
.map(|f| -> Result<TokenStream> { field_to_traverse_code(f) })
.collect::<Result<_>>()?;
let res = res.into_iter().collect::<TokenStream>();
Ok(res)
}
syn::Fields::Unnamed(fields) => {
let res: TokenStream = (0..fields.unnamed.len())
.map(|i| {
let i = syn::Index::from(i);
quote!(
::rustpython_vm::object::Traverse::traverse(&self.#i, tracer_fn);
)
})
.collect();
Ok(res)
}
_ => Err(syn::Error::new_spanned(
fields,
"Only named and unnamed fields are supported",
)),
}
}
_ => Err(syn::Error::new_spanned(item, "Only structs are supported")),
}
}
pub(crate) fn impl_pytraverse(mut item: DeriveInput) -> Result<TokenStream> {
let trace_code = gen_trace_code(&mut item)?;
let ty = &item.ident;
let ret = quote! {
unsafe impl ::rustpython_vm::object::Traverse for #ty {
fn traverse(&self, tracer_fn: &mut ::rustpython_vm::object::TraverseFn) {
#trace_code
}
}
};
Ok(ret)
}