What's the idiomatic way to constrain a trait's associated type with arithmetic? #201

Open
Ben-PH opened this issue Sep 25, 2023 · 3 comments

Comments

@Ben-PH

Ben-PH commented Sep 25, 2023

I have a trait to read data into a passed-in byte-array:

pub trait ReadState {
    type LEN: ArrayLength;
    fn read_state(&mut self, buf: &mut GenericArray<u8, Self::LEN>);
}

I'd like to read the state of a bool array. It has a width and a height, so I need to read in width * height bits, which requires a byte array of (width * height + 7) / 8 bytes. This is the solution I have:

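impl<Width: ArrayLength + Mul<Height>, Height: ArrayLength> ReadState
    for KeyDriver<InP, OutP, Width, Height, D>
where
    <Width as Mul<Height>>::Output: Add<U7>,
    <<Width as Mul<Height>>::Output as Add<U7>>::Output: Div<U8>,
    <<<Width as Mul<Height>>::Output as Add<U7>>::Output as Div<U8>>::Output: ArrayLength,
{
    // LEN = (Width * Height + 7) / 8: the bit count rounded up to whole bytes
    type LEN = <<<Width as Mul<Height>>::Output as Add<U7>>::Output as Div<U8>>::Output;

    fn read_state(&mut self, buf: &mut GenericArray<u8, Self::LEN>) {
        // ...
    }
}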

This works, but is there a cleaner way to do it? I tried doing something with the op! macro, but got massive recursion-depth errors.
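The arithmetic inside those bounds is the usual ceiling-division idiom. Here is a plain-value sketch using ordinary `const fn`s (no typenum involved; `byte_len`, `byte_len_for_bits` and `truncated_byte_len` are hypothetical names) that shows why the `+ 7` is needed: without it, integer division truncates and a partially filled last byte is lost.

```rust
/// Bytes needed to hold `bits` bits, rounded up: adding 7 before dividing by 8
/// pushes any count that is not a multiple of 8 into the next whole byte.
const fn byte_len_for_bits(bits: usize) -> usize {
    (bits + 7) / 8
}

/// Bytes needed for a `width x height` bool matrix stored one bit per cell,
/// mirroring `(Width * Height + U7) / U8` at the value level.
const fn byte_len(width: usize, height: usize) -> usize {
    byte_len_for_bits(width * height)
}

/// Plain truncating division, shown for contrast: it drops a partial tail byte.
const fn truncated_byte_len(width: usize, height: usize) -> usize {
    (width * height) / 8
}
```

For a 3 x 5 matrix (15 bits), `byte_len(3, 5)` is 2 while `truncated_byte_len(3, 5)` is 1, which would leave 7 bits with nowhere to go.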

@paholg
Owner

paholg commented Sep 25, 2023

This is how you can write that using the op! macro (note I've added enough shims to get this to compile):

use std::{
    marker::PhantomData,
    ops::{Add, Div, Mul},
};

use generic_array::{ArrayLength, GenericArray};
use typenum::{op, U7, U8};

struct InP;
struct OutP;
struct D;

struct KeyDriver<In, Out, Width, Height, D> {
    _marker: PhantomData<(In, Out, Width, Height, D)>,
}

pub trait ReadState {
    type LEN: ArrayLength;
    fn read_state(&mut self, buf: &mut GenericArray<u8, Self::LEN>);
}

impl<Width, Height> ReadState for KeyDriver<InP, OutP, Width, Height, D>
where
    Height: ArrayLength,
    Width: ArrayLength + Mul<Height>,
    op!(Width * Height): Add<U7>,
    op!(Width * Height + U7): Div<U8>,
    op!((Width * Height + U7) / U8): ArrayLength,
{
    type LEN = op!((Width * Height + U7) / U8);

    fn read_state(&mut self, buf: &mut GenericArray<u8, Self::LEN>) {}
}

fn main() {}
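A minimal sketch of what a `read_state` body might do with the buffer, assuming the matrix is available as a row-major `&[bool]` slice and that bits are packed least-significant-bit first. Both are assumptions, since the thread does not specify a layout, and `pack_bits` is a hypothetical helper that is not part of typenum or generic_array. Because `GenericArray<u8, N>` dereferences to `[u8]`, the `buf` argument in the impl above should be passable directly as `pack_bits(&cells, buf)`.

```rust
/// Hypothetical helper: packs a row-major `width * height` bool matrix into
/// `buf`, one bit per cell, least-significant bit first within each byte.
/// `buf` must hold at least `(cells.len() + 7) / 8` bytes.
fn pack_bits(cells: &[bool], buf: &mut [u8]) {
    assert!(
        buf.len() >= (cells.len() + 7) / 8,
        "buffer too small for bit matrix"
    );
    // Start from a clean buffer so stale bits from a previous read don't leak in.
    buf.fill(0);
    for (i, &cell) in cells.iter().enumerate() {
        if cell {
            // Bit i lives in byte i / 8, at bit position i % 8.
            buf[i / 8] |= 1u8 << (i % 8);
        }
    }
}
```

Reading a cell back is the mirror image: `(buf[i / 8] >> (i % 8)) & 1 == 1`.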

@paholg
Owner

paholg commented Sep 25, 2023

I am curious if you could give an example of what you tried that resulted in errors; maybe it's an opportunity to improve documentation or behavior of the macro.

@Ben-PH
Author

Ben-PH commented Oct 5, 2023

At first I just had <Width as Mul<Height>>::Output: ArrayLength, and that was my LEN associated type. For context, I'm implementing a generic read-into-byte-array method, but this specific impl is for bits, so I need a byte array that is 1/8 the length, rounded up rather than down.

I tried using op! in the LEN type, which made the compiler exceed its recursion limit (I raised the limit to something like 8k before I realized that was not the answer).

For documentation, it would help to have something that makes it easy to discover that every single integer operation has to be constrained. Maybe something like this in the examples:

use core::marker::PhantomData;
use core::ops::{Add, Div, Mul};
use typenum::{op, Unsigned, U7, U8};

/// Suppose you want to define a const-generic, but its value needs a compile-time
/// calculation. This could include something like reading a boolean matrix into a byte array.
trait TypeNumOps {
    type TypeNumRes: Unsigned;
    fn typenum_calc() -> Self::TypeNumRes;
}

/// Let's create a struct that needs compile-time calculation of a generic
struct BitMatrixIntoByteArray<X: Unsigned, Y: Unsigned> {
    _x: PhantomData<X>,
    _y: PhantomData<Y>,
}

/// In this case, we'll be doing a mul => addition => division operation sequence:
/// ((width * height) + 7) / 8
impl<WIDTH: Unsigned, HEIGHT: Unsigned> TypeNumOps for BitMatrixIntoByteArray<WIDTH, HEIGHT>
where
    // 1. The first op is width * height, so WIDTH must implement multiplication by HEIGHT.
    WIDTH: Mul<HEIGHT>,
    // 2. The second op is the first op's result + 7, so that result must impl Add<U7>.
    op!(WIDTH * HEIGHT): Add<U7>,
    // 3. The pattern continues: the result of the previous op must impl Div<U8>.
    op!((WIDTH * HEIGHT) + U7): Div<U8>,
    // 4. The result of all the ops must impl Unsigned, and we're done!
    op!(((WIDTH * HEIGHT) + U7) / U8): Unsigned,
{
    /// This is useful for calculating the number of bytes needed to store width * height bits:
    ///  1. width * height => total_bitlen
    ///  2. total_bitlen + 7 => byte_friendly_bitlen, so that dividing doesn't truncate a non-full tail byte
    ///  3. byte_friendly_bitlen / 8 => byte_buff_len
    ///  4. the final result must satisfy the same bounds as TypeNumRes (here, Unsigned)
    type TypeNumRes = op!(((WIDTH * HEIGHT) + U7) / U8);

    fn typenum_calc() -> Self::TypeNumRes {
        // typenum numbers are zero-sized and `Unsigned` requires `Default`,
        // so this yields the (only) value of the result type.
        Default::default()
    }
}
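For comparison, here is a sketch of the same mul, add, div sequence using stable const generics instead of typenum; `BitMatrix`, `ByteLen` and `zeroed_3x5` are hypothetical names. The value can be computed as an associated const, and it can size an array once the dimensions are concrete, but not while they are still generic parameters (that needs the unstable `generic_const_exprs` feature). That limitation is presumably why type-level numbers are used for the `GenericArray` length in this thread.

```rust
/// Hypothetical marker type for a `W x H` bit matrix, using const generics
/// instead of typenum type-level numbers.
struct BitMatrix<const W: usize, const H: usize>;

trait ByteLen {
    /// Number of bytes needed to store all W * H bits, rounded up.
    const BYTE_LEN: usize;
}

impl<const W: usize, const H: usize> ByteLen for BitMatrix<W, H> {
    // Same mul -> add -> div sequence as the typenum bounds, but evaluated as
    // an ordinary const expression, so no per-operation trait bounds are needed.
    const BYTE_LEN: usize = (W * H + 7) / 8;
}

/// With *concrete* dimensions the constant can size an array on stable Rust.
fn zeroed_3x5() -> [u8; <BitMatrix<3, 5> as ByteLen>::BYTE_LEN] {
    [0; <BitMatrix<3, 5> as ByteLen>::BYTE_LEN]
}

// With *generic* dimensions it cannot: a signature such as
//     fn read<const W: usize, const H: usize>(buf: &mut [u8; (W * H + 7) / 8])
// is rejected on stable ("generic parameters may not be used in const
// operations") unless the unstable `generic_const_exprs` feature is enabled.
```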
