diff --git a/src/from.rs b/src/from.rs new file mode 100644 index 00000000..b1048237 --- /dev/null +++ b/src/from.rs @@ -0,0 +1,20 @@ +//! Generate `From` implementations to convert variants to the top-level enum. +use quote::quote; +use syn::{Ident, ImplGenerics, TypeGenerics, WhereClause}; + +pub fn generate_from_trait_impl( + type_name: &Ident, + impl_generics: &ImplGenerics, + ty_generics: &TypeGenerics, + where_clause: &Option<&WhereClause>, + variant_name: &Ident, + struct_name: &Ident, +) -> proc_macro2::TokenStream { + quote! { + impl #impl_generics From<#struct_name #ty_generics> for #type_name #ty_generics #where_clause { + fn from(variant: #struct_name #ty_generics) -> Self { + Self::#variant_name(variant) + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index af9355e4..fd6a7eda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ use attributes::{IdentList, NestedMetaList}; use darling::FromMeta; +use from::generate_from_trait_impl; use itertools::Itertools; use macros::generate_all_map_macros; use proc_macro::TokenStream; @@ -13,6 +14,7 @@ use syn::{ }; mod attributes; +mod from; mod macros; mod naming; mod utils; @@ -516,6 +518,19 @@ pub fn superstruct(args: TokenStream, input: TokenStream) -> TokenStream { ); } + // Generate trait implementations. + for (variant_name, struct_name) in variant_names.iter().zip_eq(&struct_names) { + let from_impl = generate_from_trait_impl( + type_name, + impl_generics, + ty_generics, + where_clause, + variant_name, + struct_name, + ); + output_items.push(from_impl.into()); + } + TokenStream::from_iter(output_items) } diff --git a/tests/from.rs b/tests/from.rs new file mode 100644 index 00000000..45f5d177 --- /dev/null +++ b/tests/from.rs @@ -0,0 +1,34 @@ +use std::fmt::Display; +use superstruct::superstruct; + +#[superstruct(variants(Good, Bad), variant_attributes(derive(Debug, PartialEq)))] +#[derive(Debug, PartialEq)] +pub struct Message { + #[superstruct(getter(copy))] + id: u64, + #[superstruct(only(Good))] + good: T, + #[superstruct(only(Bad))] + bad: T, +} + +#[test] +fn generic_from() { + let message_good_variant = MessageGood { + id: 0, + good: "hello", + }; + let message_bad_variant = MessageBad { + id: 1, + bad: "noooooo", + }; + + let message_good = Message::from(message_good_variant); + let message_bad = Message::from(message_bad_variant); + + assert_eq!(message_good.id(), 0); + assert_eq!(*message_good.good().unwrap(), "hello"); + + assert_eq!(message_bad.id(), 1); + assert_eq!(*message_bad.bad().unwrap(), "noooooo"); +}