diff --git a/src/lib.rs b/src/lib.rs index 138b752..3a954f3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -232,12 +232,13 @@ impl NumTraits { pub fn from_primitive(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let name = &ast.ident; + let (impl_, type_, where_) = split_for_impl(&ast.generics); let import = NumTraits::new(&ast); let impl_ = if let Some(inner_ty) = newtype_inner(&ast.data) { quote! { - impl #import::FromPrimitive for #name { + impl #impl_ #import::FromPrimitive for #name #type_ #where_ #inner_ty: #import::FromPrimitive { fn from_i64(n: i64) -> Option { <#inner_ty as #import::FromPrimitive>::from_i64(n).map(#name) } @@ -320,7 +321,7 @@ pub fn from_primitive(input: TokenStream) -> TokenStream { }; quote! { - impl #import::FromPrimitive for #name { + impl #impl_ #import::FromPrimitive for #name #type_ #where_ { #[allow(trivial_numeric_casts)] fn from_i64(#from_i64_var: i64) -> Option { #(#clauses else)* { @@ -390,12 +391,13 @@ pub fn from_primitive(input: TokenStream) -> TokenStream { pub fn to_primitive(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let name = &ast.ident; + let (impl_, type_, where_) = split_for_impl(&ast.generics); let import = NumTraits::new(&ast); let impl_ = if let Some(inner_ty) = newtype_inner(&ast.data) { quote! { - impl #import::ToPrimitive for #name { + impl #impl_ #import::ToPrimitive for #name #type_ #where_ #inner_ty: #import::ToPrimitive { fn to_i64(&self) -> Option { <#inner_ty as #import::ToPrimitive>::to_i64(&self.0) } @@ -481,7 +483,7 @@ pub fn to_primitive(input: TokenStream) -> TokenStream { }; quote! { - impl #import::ToPrimitive for #name { + impl #impl_ #import::ToPrimitive for #name #type_ #where_ { #[allow(trivial_numeric_casts)] fn to_i64(&self) -> Option { #match_expr @@ -511,33 +513,34 @@ const NEWTYPE_ONLY: &str = "This trait can only be derived for newtypes"; pub fn num_ops(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let name = &ast.ident; + let (impl_, type_, where_) = split_for_impl(&ast.generics); let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY); let impl_ = quote! { - impl ::std::ops::Add for #name { + impl #impl_ ::std::ops::Add for #name #type_ #where_ #inner_ty: ::std::ops::Add { type Output = Self; fn add(self, other: Self) -> Self { #name(<#inner_ty as ::std::ops::Add>::add(self.0, other.0)) } } - impl ::std::ops::Sub for #name { + impl #impl_ ::std::ops::Sub for #name #type_ #where_ #inner_ty: ::std::ops::Sub { type Output = Self; fn sub(self, other: Self) -> Self { #name(<#inner_ty as ::std::ops::Sub>::sub(self.0, other.0)) } } - impl ::std::ops::Mul for #name { + impl #impl_ ::std::ops::Mul for #name #type_ #where_ #inner_ty: ::std::ops::Mul { type Output = Self; fn mul(self, other: Self) -> Self { #name(<#inner_ty as ::std::ops::Mul>::mul(self.0, other.0)) } } - impl ::std::ops::Div for #name { + impl #impl_ ::std::ops::Div for #name #type_ #where_ #inner_ty: ::std::ops::Div { type Output = Self; fn div(self, other: Self) -> Self { #name(<#inner_ty as ::std::ops::Div>::div(self.0, other.0)) } } - impl ::std::ops::Rem for #name { + impl #impl_ ::std::ops::Rem for #name #type_ #where_ #inner_ty: ::std::ops::Rem { type Output = Self; fn rem(self, other: Self) -> Self { #name(<#inner_ty as ::std::ops::Rem>::rem(self.0, other.0)) @@ -555,13 +558,16 @@ pub fn num_ops(input: TokenStream) -> TokenStream { pub fn num_cast(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let name = &ast.ident; + let (impl_, type_, where_) = split_for_impl(&ast.generics); let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY); + let fn_param = proc_macro2::Ident::new("FROM_T", name.span()); let import = NumTraits::new(&ast); let impl_ = quote! { - impl #import::NumCast for #name { - fn from(n: T) -> Option { + impl #impl_ #import::NumCast for #name #type_ #where_ #inner_ty: #import::NumCast { + #[allow(non_camel_case_types)] + fn from<#fn_param: #import::ToPrimitive>(n: #fn_param) -> Option { <#inner_ty as #import::NumCast>::from(n).map(#name) } } @@ -577,12 +583,13 @@ pub fn num_cast(input: TokenStream) -> TokenStream { pub fn zero(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let name = &ast.ident; + let (impl_, type_, where_) = split_for_impl(&ast.generics); let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY); let import = NumTraits::new(&ast); let impl_ = quote! { - impl #import::Zero for #name { + impl #impl_ #import::Zero for #name #type_ #where_ #inner_ty: #import::Zero { fn zero() -> Self { #name(<#inner_ty as #import::Zero>::zero()) } @@ -602,12 +609,13 @@ pub fn zero(input: TokenStream) -> TokenStream { pub fn one(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let name = &ast.ident; + let (impl_, type_, where_) = split_for_impl(&ast.generics); let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY); let import = NumTraits::new(&ast); let impl_ = quote! { - impl #import::One for #name { + impl #impl_ #import::One for #name #type_ #where_ #inner_ty: #import::One + PartialEq { fn one() -> Self { #name(<#inner_ty as #import::One>::one()) } @@ -620,6 +628,17 @@ pub fn one(input: TokenStream) -> TokenStream { import.wrap("One", &name, impl_).into() } +fn split_for_impl( + generics: &syn::Generics, +) -> (syn::ImplGenerics, syn::TypeGenerics, impl quote::ToTokens) { + let (impl_, type_, where_) = generics.split_for_impl(); + let where_ = match where_ { + Some(where_) => quote! { #where_, }, + None => quote! { where }, + }; + (impl_, type_, where_) +} + /// Derives [`num_traits::Num`][num] for newtypes. The inner type must already implement `Num`. /// /// [num]: https://docs.rs/num-traits/0.2/num_traits/trait.Num.html @@ -627,12 +646,13 @@ pub fn one(input: TokenStream) -> TokenStream { pub fn num(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let name = &ast.ident; + let (impl_, type_, where_) = split_for_impl(&ast.generics); let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY); let import = NumTraits::new(&ast); let impl_ = quote! { - impl #import::Num for #name { + impl #impl_ #import::Num for #name #type_ #where_ #inner_ty: #import::Num { type FromStrRadixErr = <#inner_ty as #import::Num>::FromStrRadixErr; fn from_str_radix(s: &str, radix: u32) -> Result { <#inner_ty as #import::Num>::from_str_radix(s, radix).map(#name) @@ -651,12 +671,13 @@ pub fn num(input: TokenStream) -> TokenStream { pub fn float(input: TokenStream) -> TokenStream { let ast: syn::DeriveInput = syn::parse(input).unwrap(); let name = &ast.ident; + let (impl_, type_, where_) = split_for_impl(&ast.generics); let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY); let import = NumTraits::new(&ast); let impl_ = quote! { - impl #import::Float for #name { + impl #impl_ #import::Float for #name #type_ #where_ #inner_ty: #import::Float { fn nan() -> Self { #name(<#inner_ty as #import::Float>::nan()) } diff --git a/tests/generic_newtype.rs b/tests/generic_newtype.rs new file mode 100644 index 0000000..725f14f --- /dev/null +++ b/tests/generic_newtype.rs @@ -0,0 +1,97 @@ +extern crate num as num_renamed; +#[macro_use] +extern crate num_derive; + +use crate::num_renamed::{Float, FromPrimitive, Num, NumCast, One, ToPrimitive, Zero}; +use std::ops::Neg; + +#[derive( + Debug, + Clone, + Copy, + PartialEq, + PartialOrd, + ToPrimitive, + FromPrimitive, + NumOps, + NumCast, + One, + Zero, + Num, + Float, +)] +struct MyThing(T) +where + T: Lie; + +trait Cake {} +trait Lie {} + +impl Cake for f32 {} +impl Lie for f32 {} + +impl + Cake + Lie> Neg for MyThing { + type Output = Self; + fn neg(self) -> Self { + MyThing(self.0.neg()) + } +} + +#[test] +fn test_from_primitive() { + assert_eq!(MyThing::from_u32(25), Some(MyThing(25.0))); +} + +#[test] +fn test_from_primitive_128() { + assert_eq!( + MyThing::from_i128(std::i128::MIN), + Some(MyThing((-2.0).powi(127))) + ); +} + +#[test] +fn test_to_primitive() { + assert_eq!(MyThing(25.0).to_u32(), Some(25)); +} + +#[test] +fn test_to_primitive_128() { + let f: MyThing = MyThing::from_f32(std::f32::MAX).unwrap(); + assert_eq!(f.to_i128(), None); + assert_eq!(f.to_u128(), Some(0xffff_ff00_0000_0000_0000_0000_0000_0000)); +} + +#[test] +fn test_num_ops() { + assert_eq!(MyThing(25.0) + MyThing(10.0), MyThing(35.0)); + assert_eq!(MyThing(25.0) - MyThing(10.0), MyThing(15.0)); + assert_eq!(MyThing(25.0) * MyThing(2.0), MyThing(50.0)); + assert_eq!(MyThing(25.0) / MyThing(10.0), MyThing(2.5)); + assert_eq!(MyThing(25.0) % MyThing(10.0), MyThing(5.0)); +} + +#[test] +fn test_num_cast() { + assert_eq!( as NumCast>::from(25u8), Some(MyThing(25.0))); +} + +#[test] +fn test_zero() { + assert_eq!(MyThing::zero(), MyThing(0.0)); +} + +#[test] +fn test_one() { + assert_eq!(MyThing::one(), MyThing(1.0)); +} + +#[test] +fn test_num() { + assert_eq!(MyThing::from_str_radix("25", 10).ok(), Some(MyThing(25.0))); +} + +#[test] +fn test_float() { + assert_eq!(MyThing(4.0).log(MyThing(2.0)), MyThing(2.0)); +}