Skip to content

Commit

Permalink
Implement derives for generic wrapper types
Browse files Browse the repository at this point in the history
  • Loading branch information
Oliver Scherer committed Jun 30, 2020
1 parent ed849ab commit c9764cd
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 15 deletions.
51 changes: 36 additions & 15 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,13 @@ impl NumTraits {
pub fn from_primitive(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);

let import = NumTraits::new(&ast);

let impl_ = if let Some(inner_ty) = newtype_inner(&ast.data) {
quote! {
impl #import::FromPrimitive for #name {
impl #impl_ #import::FromPrimitive for #name #type_ #where_ #inner_ty: #import::FromPrimitive {
fn from_i64(n: i64) -> Option<Self> {
<#inner_ty as #import::FromPrimitive>::from_i64(n).map(#name)
}
Expand Down Expand Up @@ -320,7 +321,7 @@ pub fn from_primitive(input: TokenStream) -> TokenStream {
};

quote! {
impl #import::FromPrimitive for #name {
impl #impl_ #import::FromPrimitive for #name #type_ #where_ {
#[allow(trivial_numeric_casts)]
fn from_i64(#from_i64_var: i64) -> Option<Self> {
#(#clauses else)* {
Expand Down Expand Up @@ -390,12 +391,13 @@ pub fn from_primitive(input: TokenStream) -> TokenStream {
pub fn to_primitive(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);

let import = NumTraits::new(&ast);

let impl_ = if let Some(inner_ty) = newtype_inner(&ast.data) {
quote! {
impl #import::ToPrimitive for #name {
impl #impl_ #import::ToPrimitive for #name #type_ #where_ #inner_ty: #import::ToPrimitive {
fn to_i64(&self) -> Option<i64> {
<#inner_ty as #import::ToPrimitive>::to_i64(&self.0)
}
Expand Down Expand Up @@ -481,7 +483,7 @@ pub fn to_primitive(input: TokenStream) -> TokenStream {
};

quote! {
impl #import::ToPrimitive for #name {
impl #impl_ #import::ToPrimitive for #name #type_ #where_ {
#[allow(trivial_numeric_casts)]
fn to_i64(&self) -> Option<i64> {
#match_expr
Expand Down Expand Up @@ -511,33 +513,34 @@ const NEWTYPE_ONLY: &str = "This trait can only be derived for newtypes";
pub fn num_ops(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
let impl_ = quote! {
impl ::std::ops::Add for #name {
impl #impl_ ::std::ops::Add for #name #type_ #where_ #inner_ty: ::std::ops::Add<Output = #inner_ty> {
type Output = Self;
fn add(self, other: Self) -> Self {
#name(<#inner_ty as ::std::ops::Add>::add(self.0, other.0))
}
}
impl ::std::ops::Sub for #name {
impl #impl_ ::std::ops::Sub for #name #type_ #where_ #inner_ty: ::std::ops::Sub<Output = #inner_ty> {
type Output = Self;
fn sub(self, other: Self) -> Self {
#name(<#inner_ty as ::std::ops::Sub>::sub(self.0, other.0))
}
}
impl ::std::ops::Mul for #name {
impl #impl_ ::std::ops::Mul for #name #type_ #where_ #inner_ty: ::std::ops::Mul<Output = #inner_ty> {
type Output = Self;
fn mul(self, other: Self) -> Self {
#name(<#inner_ty as ::std::ops::Mul>::mul(self.0, other.0))
}
}
impl ::std::ops::Div for #name {
impl #impl_ ::std::ops::Div for #name #type_ #where_ #inner_ty: ::std::ops::Div<Output = #inner_ty> {
type Output = Self;
fn div(self, other: Self) -> Self {
#name(<#inner_ty as ::std::ops::Div>::div(self.0, other.0))
}
}
impl ::std::ops::Rem for #name {
impl #impl_ ::std::ops::Rem for #name #type_ #where_ #inner_ty: ::std::ops::Rem<Output = #inner_ty> {
type Output = Self;
fn rem(self, other: Self) -> Self {
#name(<#inner_ty as ::std::ops::Rem>::rem(self.0, other.0))
Expand All @@ -555,13 +558,16 @@ pub fn num_ops(input: TokenStream) -> TokenStream {
pub fn num_cast(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
let fn_param = proc_macro2::Ident::new("FROM_T", name.span());

let import = NumTraits::new(&ast);

let impl_ = quote! {
impl #import::NumCast for #name {
fn from<T: #import::ToPrimitive>(n: T) -> Option<Self> {
impl #impl_ #import::NumCast for #name #type_ #where_ #inner_ty: #import::NumCast {
#[allow(non_camel_case_types)]
fn from<#fn_param: #import::ToPrimitive>(n: #fn_param) -> Option<Self> {
<#inner_ty as #import::NumCast>::from(n).map(#name)
}
}
Expand All @@ -577,12 +583,13 @@ pub fn num_cast(input: TokenStream) -> TokenStream {
pub fn zero(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);

let import = NumTraits::new(&ast);

let impl_ = quote! {
impl #import::Zero for #name {
impl #impl_ #import::Zero for #name #type_ #where_ #inner_ty: #import::Zero {
fn zero() -> Self {
#name(<#inner_ty as #import::Zero>::zero())
}
Expand All @@ -602,12 +609,13 @@ pub fn zero(input: TokenStream) -> TokenStream {
pub fn one(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);

let import = NumTraits::new(&ast);

let impl_ = quote! {
impl #import::One for #name {
impl #impl_ #import::One for #name #type_ #where_ #inner_ty: #import::One + PartialEq {
fn one() -> Self {
#name(<#inner_ty as #import::One>::one())
}
Expand All @@ -620,19 +628,31 @@ pub fn one(input: TokenStream) -> TokenStream {
import.wrap("One", &name, impl_).into()
}

fn split_for_impl(
generics: &syn::Generics,
) -> (syn::ImplGenerics, syn::TypeGenerics, impl quote::ToTokens) {
let (impl_, type_, where_) = generics.split_for_impl();
let where_ = match where_ {
Some(where_) => quote! { #where_, },
None => quote! { where },
};
(impl_, type_, where_)
}

/// Derives [`num_traits::Num`][num] for newtypes. The inner type must already implement `Num`.
///
/// [num]: https://docs.rs/num-traits/0.2/num_traits/trait.Num.html
#[proc_macro_derive(Num, attributes(num_traits))]
pub fn num(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);

let import = NumTraits::new(&ast);

let impl_ = quote! {
impl #import::Num for #name {
impl #impl_ #import::Num for #name #type_ #where_ #inner_ty: #import::Num {
type FromStrRadixErr = <#inner_ty as #import::Num>::FromStrRadixErr;
fn from_str_radix(s: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
<#inner_ty as #import::Num>::from_str_radix(s, radix).map(#name)
Expand All @@ -651,12 +671,13 @@ pub fn num(input: TokenStream) -> TokenStream {
pub fn float(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);

let import = NumTraits::new(&ast);

let impl_ = quote! {
impl #import::Float for #name {
impl #impl_ #import::Float for #name #type_ #where_ #inner_ty: #import::Float {
fn nan() -> Self {
#name(<#inner_ty as #import::Float>::nan())
}
Expand Down
97 changes: 97 additions & 0 deletions tests/generic_newtype.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
extern crate num as num_renamed;
#[macro_use]
extern crate num_derive;

use crate::num_renamed::{Float, FromPrimitive, Num, NumCast, One, ToPrimitive, Zero};
use std::ops::Neg;

#[derive(
Debug,
Clone,
Copy,
PartialEq,
PartialOrd,
ToPrimitive,
FromPrimitive,
NumOps,
NumCast,
One,
Zero,
Num,
Float,
)]
struct MyThing<T: Cake>(T)
where
T: Lie;

trait Cake {}
trait Lie {}

impl Cake for f32 {}
impl Lie for f32 {}

impl<T: Neg<Output = T> + Cake + Lie> Neg for MyThing<T> {
type Output = Self;
fn neg(self) -> Self {
MyThing(self.0.neg())
}
}

#[test]
fn test_from_primitive() {
assert_eq!(MyThing::from_u32(25), Some(MyThing(25.0)));
}

#[test]
fn test_from_primitive_128() {
assert_eq!(
MyThing::from_i128(std::i128::MIN),
Some(MyThing((-2.0).powi(127)))
);
}

#[test]
fn test_to_primitive() {
assert_eq!(MyThing(25.0).to_u32(), Some(25));
}

#[test]
fn test_to_primitive_128() {
let f: MyThing<f32> = MyThing::from_f32(std::f32::MAX).unwrap();
assert_eq!(f.to_i128(), None);
assert_eq!(f.to_u128(), Some(0xffff_ff00_0000_0000_0000_0000_0000_0000));
}

#[test]
fn test_num_ops() {
assert_eq!(MyThing(25.0) + MyThing(10.0), MyThing(35.0));
assert_eq!(MyThing(25.0) - MyThing(10.0), MyThing(15.0));
assert_eq!(MyThing(25.0) * MyThing(2.0), MyThing(50.0));
assert_eq!(MyThing(25.0) / MyThing(10.0), MyThing(2.5));
assert_eq!(MyThing(25.0) % MyThing(10.0), MyThing(5.0));
}

#[test]
fn test_num_cast() {
assert_eq!(<MyThing<f32> as NumCast>::from(25u8), Some(MyThing(25.0)));
}

#[test]
fn test_zero() {
assert_eq!(MyThing::zero(), MyThing(0.0));
}

#[test]
fn test_one() {
assert_eq!(MyThing::one(), MyThing(1.0));
}

#[test]
fn test_num() {
assert_eq!(MyThing::from_str_radix("25", 10).ok(), Some(MyThing(25.0)));
}

#[test]
fn test_float() {
assert_eq!(MyThing(4.0).log(MyThing(2.0)), MyThing(2.0));
}

0 comments on commit c9764cd

Please sign in to comment.