Skip to content

Commit

Permalink
Add SIMD backend arithmetic (#592)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed May 15, 2024
1 parent e230709 commit 8d61196
Show file tree
Hide file tree
Showing 9 changed files with 1,157 additions and 3 deletions.
26 changes: 26 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@ edition.workspace = true
[dependencies]
blake2.workspace = true
blake3.workspace = true
bytemuck = { workspace = true, features = ["derive"] }
cfg-if = "1.0.0"
derivative.workspace = true
hex.workspace = true
itertools.workspace = true
num-traits.workspace = true
thiserror.workspace = true
bytemuck = { workspace = true, features = ["derive"] }
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
thiserror.workspace = true
tracing.workspace = true

[dev-dependencies]
aligned = "0.4.2"
criterion = { version = "0.5.1", features = ["html_reports"] }
test-log = { version = "0.2.15", features = ["trace"] }
tracing-subscriber = "0.3.18"
Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/core/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use super::poly::circle::PolyOps;
#[cfg(target_arch = "x86_64")]
pub mod avx512;
pub mod cpu;
pub mod simd;

pub trait Backend:
Copy
Expand Down
221 changes: 221 additions & 0 deletions crates/prover/src/core/backend/simd/cm31.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
use std::array;
use std::ops::{Add, Mul, MulAssign, Neg, Sub};

use num_traits::{One, Zero};

use super::m31::{PackedM31, N_LANES};
use crate::core::fields::cm31::CM31;
use crate::core::fields::FieldExpOps;

/// SIMD implementation of [`CM31`].
#[derive(Copy, Clone, Debug)]
pub struct PackedCM31(pub [PackedM31; 2]);

impl PackedCM31 {
/// Constructs a new instance with all vector elements set to `value`.
pub fn broadcast(value: CM31) -> Self {
Self([PackedM31::broadcast(value.0), PackedM31::broadcast(value.1)])
}

/// Returns all `a` values such that each vector element is represented as `a + bi`.
pub fn a(&self) -> PackedM31 {
self.0[0]
}

/// Returns all `b` values such that each vector element is represented as `a + bi`.
pub fn b(&self) -> PackedM31 {
self.0[1]
}

pub fn to_array(&self) -> [CM31; N_LANES] {
let a = self.a().to_array();
let b = self.b().to_array();
array::from_fn(|i| CM31(a[i], b[i]))
}

pub fn from_array(values: [CM31; N_LANES]) -> Self {
Self([
PackedM31::from_array(values.map(|v| v.0)),
PackedM31::from_array(values.map(|v| v.1)),
])
}

/// Interleaves two vectors.
pub fn interleave(self, other: Self) -> (Self, Self) {
let Self([a_evens, b_evens]) = self;
let Self([a_odds, b_odds]) = other;
let (a_lhs, a_rhs) = a_evens.interleave(a_odds);
let (b_lhs, b_rhs) = b_evens.interleave(b_odds);
(Self([a_lhs, b_lhs]), Self([a_rhs, b_rhs]))
}

/// Deinterleaves two vectors.
pub fn deinterleave(self, other: Self) -> (Self, Self) {
let Self([a_self, b_self]) = self;
let Self([a_other, b_other]) = other;
let (a_evens, a_odds) = a_self.deinterleave(a_other);
let (b_evens, b_odds) = b_self.deinterleave(b_other);
(Self([a_evens, b_evens]), Self([a_odds, b_odds]))
}

/// Doubles each element in the vector.
pub fn double(self) -> Self {
let Self([a, b]) = self;
Self([a.double(), b.double()])
}
}

impl Add for PackedCM31 {
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
Self([self.a() + rhs.a(), self.b() + rhs.b()])
}
}

impl Sub for PackedCM31 {
type Output = Self;

fn sub(self, rhs: Self) -> Self::Output {
Self([self.a() - rhs.a(), self.b() - rhs.b()])
}
}

impl Mul for PackedCM31 {
type Output = Self;

fn mul(self, rhs: Self) -> Self::Output {
// Compute using Karatsuba.
let ac = self.a() * rhs.a();
let bd = self.b() * rhs.b();
// Computes (a + b) * (c + d).
let ab_t_cd = (self.a() + self.b()) * (rhs.a() + rhs.b());
// (ac - bd) + (ad + bc)i.
Self([ac - bd, ab_t_cd - ac - bd])
}
}

impl Zero for PackedCM31 {
fn zero() -> Self {
Self([PackedM31::zero(), PackedM31::zero()])
}

fn is_zero(&self) -> bool {
self.a().is_zero() && self.b().is_zero()
}
}

impl One for PackedCM31 {
fn one() -> Self {
Self([PackedM31::one(), PackedM31::zero()])
}
}

impl MulAssign for PackedCM31 {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}

impl FieldExpOps for PackedCM31 {
fn inverse(&self) -> Self {
assert!(!self.is_zero(), "0 has no inverse");
// 1 / (a + bi) = (a - bi) / (a^2 + b^2).
Self([self.a(), -self.b()]) * (self.a().square() + self.b().square()).inverse()
}
}

impl Add<PackedM31> for PackedCM31 {
type Output = Self;

fn add(self, rhs: PackedM31) -> Self::Output {
Self([self.a() + rhs, self.b()])
}
}

impl Sub<PackedM31> for PackedCM31 {
type Output = Self;

fn sub(self, rhs: PackedM31) -> Self::Output {
let Self([a, b]) = self;
Self([a - rhs, b])
}
}

impl Mul<PackedM31> for PackedCM31 {
type Output = Self;

fn mul(self, rhs: PackedM31) -> Self::Output {
let Self([a, b]) = self;
Self([a * rhs, b * rhs])
}
}

impl Neg for PackedCM31 {
type Output = Self;

fn neg(self) -> Self::Output {
let Self([a, b]) = self;
Self([-a, -b])
}
}

#[cfg(test)]
mod tests {
use std::array;

use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use crate::core::backend::simd::cm31::PackedCM31;

#[test]
fn addition_works() {
let mut rng = SmallRng::seed_from_u64(0);
let lhs = rng.gen();
let rhs = rng.gen();
let packed_lhs = PackedCM31::from_array(lhs);
let packed_rhs = PackedCM31::from_array(rhs);

let res = packed_lhs + packed_rhs;

assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i]));
}

#[test]
fn subtraction_works() {
let mut rng = SmallRng::seed_from_u64(0);
let lhs = rng.gen();
let rhs = rng.gen();
let packed_lhs = PackedCM31::from_array(lhs);
let packed_rhs = PackedCM31::from_array(rhs);

let res = packed_lhs - packed_rhs;

assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i]));
}

#[test]
fn multiplication_works() {
let mut rng = SmallRng::seed_from_u64(0);
let lhs = rng.gen();
let rhs = rng.gen();
let packed_lhs = PackedCM31::from_array(lhs);
let packed_rhs = PackedCM31::from_array(rhs);

let res = packed_lhs * packed_rhs;

assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i]));
}

#[test]
fn negation_works() {
let mut rng = SmallRng::seed_from_u64(0);
let values = rng.gen();
let packed_values = PackedCM31::from_array(values);

let res = -packed_values;

assert_eq!(res.to_array(), values.map(|v| -v));
}
}
Loading

0 comments on commit 8d61196

Please sign in to comment.