Skip to content

Derive Generic FromSql/ToSql #931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion postgres-derive-test/src/composites.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::test_type;
use crate::{test_type, test_type_asymmetric};
use postgres::{Client, NoTls};
use postgres_types::{FromSql, ToSql, WrongType};
use std::error::Error;
Expand Down Expand Up @@ -238,3 +238,68 @@ fn raw_ident_field() {

test_type(&mut conn, "inventory_item", &[(item, "ROW('foo')")]);
}

#[test]
fn generics() {
#[derive(FromSql, Debug, PartialEq)]
struct InventoryItem<T: Clone, U>
where
U: Clone,
{
name: String,
supplier_id: T,
price: Option<U>,
}

// doesn't make sense to implement derived FromSql on a type with borrows
#[derive(ToSql, Debug, PartialEq)]
#[postgres(name = "InventoryItem")]
struct InventoryItemRef<'a, T: 'a + Clone, U>
where
U: 'a + Clone,
{
name: &'a str,
supplier_id: &'a T,
price: Option<&'a U>,
}

const NAME: &str = "foobar";
const SUPPLIER_ID: i32 = 100;
const PRICE: f64 = 15.50;

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.batch_execute(
"CREATE TYPE pg_temp.\"InventoryItem\" AS (
name TEXT,
supplier_id INT,
price DOUBLE PRECISION
);",
)
.unwrap();

let item = InventoryItemRef {
name: NAME,
supplier_id: &SUPPLIER_ID,
price: Some(&PRICE),
};

let item_null = InventoryItemRef {
name: NAME,
supplier_id: &SUPPLIER_ID,
price: None,
};

test_type_asymmetric(
&mut conn,
"\"InventoryItem\"",
&[
(item, "ROW('foobar', 100, 15.50)"),
(item_null, "ROW('foobar', 100, NULL)"),
],
|t: &InventoryItemRef<i32, f64>, f: &InventoryItem<i32, f64>| {
t.name == f.name.as_str()
&& t.supplier_id == &f.supplier_id
&& t.price == f.price.as_ref()
},
);
}
24 changes: 24 additions & 0 deletions postgres-derive-test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,30 @@ where
}
}

pub fn test_type_asymmetric<T, F, S, C>(
conn: &mut Client,
sql_type: &str,
checks: &[(T, S)],
cmp: C,
) where
T: ToSql + Sync,
F: FromSqlOwned,
S: fmt::Display,
C: Fn(&T, &F) -> bool,
{
for &(ref val, ref repr) in checks.iter() {
let stmt = conn
.prepare(&*format!("SELECT {}::{}", *repr, sql_type))
.unwrap();
let result: F = conn.query_one(&stmt, &[]).unwrap().get(0);
assert!(cmp(val, &result));

let stmt = conn.prepare(&*format!("SELECT $1::{}", sql_type)).unwrap();
let result: F = conn.query_one(&stmt, &[val]).unwrap().get(0);
assert!(cmp(val, &result));
}
}

#[test]
fn compile_fail() {
trybuild::TestCases::new().compile_fail("src/compile-fail/*.rs");
Expand Down
26 changes: 25 additions & 1 deletion postgres-derive/src/composites.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use syn::{Error, Ident, Type};
use proc_macro2::Span;
use syn::{
punctuated::Punctuated, Error, GenericParam, Generics, Ident, Path, PathSegment, Type,
TypeParamBound,
};

use crate::overrides::Overrides;

Expand Down Expand Up @@ -26,3 +30,23 @@ impl Field {
})
}
}

pub(crate) fn append_generic_bound(mut generics: Generics, bound: &TypeParamBound) -> Generics {
for param in &mut generics.params {
if let GenericParam::Type(param) = param {
param.bounds.push(bound.to_owned())
}
}
generics
}

pub(crate) fn new_derive_path(last: PathSegment) -> Path {
let mut path = Path {
leading_colon: None,
segments: Punctuated::new(),
};
path.segments
.push(Ident::new("postgres_types", Span::call_site()).into());
path.segments.push(last);
path
}
51 changes: 46 additions & 5 deletions postgres-derive/src/fromsql.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use proc_macro2::TokenStream;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use std::iter;
use syn::{Data, DataStruct, DeriveInput, Error, Fields, Ident};
use syn::{
punctuated::Punctuated, token, AngleBracketedGenericArguments, Data, DataStruct, DeriveInput,
Error, Fields, GenericArgument, GenericParam, Generics, Ident, Lifetime, LifetimeDef,
PathArguments, PathSegment,
};
use syn::{TraitBound, TraitBoundModifier, TypeParamBound};

use crate::accepts;
use crate::composites::Field;
use crate::composites::{append_generic_bound, new_derive_path};
use crate::enums::Variant;
use crate::overrides::Overrides;

Expand Down Expand Up @@ -86,10 +92,13 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
};

let ident = &input.ident;
let (generics, lifetime) = build_generics(&input.generics);
let (impl_generics, _, _) = generics.split_for_impl();
let (_, ty_generics, where_clause) = input.generics.split_for_impl();
let out = quote! {
impl<'a> postgres_types::FromSql<'a> for #ident {
fn from_sql(_type: &postgres_types::Type, buf: &'a [u8])
-> std::result::Result<#ident,
impl#impl_generics postgres_types::FromSql<#lifetime> for #ident#ty_generics #where_clause {
fn from_sql(_type: &postgres_types::Type, buf: &#lifetime [u8])
-> std::result::Result<#ident#ty_generics,
std::boxed::Box<dyn std::error::Error +
std::marker::Sync +
std::marker::Send>> {
Expand Down Expand Up @@ -200,3 +209,35 @@ fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
})
}
}

fn build_generics(source: &Generics) -> (Generics, Lifetime) {
// don't worry about lifetime name collisions, it doesn't make sense to derive FromSql on a struct with a lifetime
let lifetime = Lifetime::new("'a", Span::call_site());

let mut out = append_generic_bound(source.to_owned(), &new_fromsql_bound(&lifetime));
out.params.insert(
0,
GenericParam::Lifetime(LifetimeDef::new(lifetime.to_owned())),
);

(out, lifetime)
}

fn new_fromsql_bound(lifetime: &Lifetime) -> TypeParamBound {
let mut path_segment: PathSegment = Ident::new("FromSql", Span::call_site()).into();
let mut seg_args = Punctuated::new();
seg_args.push(GenericArgument::Lifetime(lifetime.to_owned()));
path_segment.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments {
colon2_token: None,
lt_token: token::Lt::default(),
args: seg_args,
gt_token: token::Gt::default(),
});

TypeParamBound::Trait(TraitBound {
lifetimes: None,
modifier: TraitBoundModifier::None,
paren_token: None,
path: new_derive_path(path_segment),
})
}
21 changes: 18 additions & 3 deletions postgres-derive/src/tosql.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use proc_macro2::TokenStream;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use std::iter;
use syn::{Data, DataStruct, DeriveInput, Error, Fields, Ident};
use syn::{
Data, DataStruct, DeriveInput, Error, Fields, Ident, TraitBound, TraitBoundModifier,
TypeParamBound,
};

use crate::accepts;
use crate::composites::Field;
use crate::composites::{append_generic_bound, new_derive_path};
use crate::enums::Variant;
use crate::overrides::Overrides;

Expand Down Expand Up @@ -82,8 +86,10 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
};

let ident = &input.ident;
let generics = append_generic_bound(input.generics.to_owned(), &new_tosql_bound());
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let out = quote! {
impl postgres_types::ToSql for #ident {
impl#impl_generics postgres_types::ToSql for #ident#ty_generics #where_clause {
fn to_sql(&self,
_type: &postgres_types::Type,
buf: &mut postgres_types::private::BytesMut)
Expand Down Expand Up @@ -181,3 +187,12 @@ fn composite_body(fields: &[Field]) -> TokenStream {
std::result::Result::Ok(postgres_types::IsNull::No)
}
}

fn new_tosql_bound() -> TypeParamBound {
TypeParamBound::Trait(TraitBound {
lifetimes: None,
modifier: TraitBoundModifier::None,
paren_token: None,
path: new_derive_path(Ident::new("ToSql", Span::call_site()).into()),
})
}