Skip to content

Commit

Permalink
Add mul, add, sub, neg, eq of Scalar
Browse files Browse the repository at this point in the history
Multliplication took the place of multiplying with point. To multiply a
scalar with a point use multiply operator of the point instead.
  • Loading branch information
tgalal committed Jan 6, 2023
1 parent 7d5d569 commit 57361ba
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 17 deletions.
28 changes: 25 additions & 3 deletions src/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pyo3::prelude::*;
use curve25519_dalek::scalar::Scalar as _Scalar;
use crate::ristretto::{RistrettoPoint};
use pyo3::basic::CompareOp;

#[pyclass]
pub struct Scalar(pub _Scalar);
Expand All @@ -12,8 +12,30 @@ impl Scalar {
Scalar(_Scalar::from(x))
}

pub fn __mul__(&self, p : &RistrettoPoint) -> RistrettoPoint {
RistrettoPoint(self.0 * p.0)
pub fn __mul__(&self, p : &Scalar) -> Scalar {
Scalar(self.0 * p.0)
}

pub fn __add__(&self, p : &Scalar) -> Scalar {
Scalar(self.0 + p.0)
}

pub fn __sub__(&self, p : &Scalar) -> Scalar {
Scalar(self.0 - p.0)
}

pub fn __neg__(&self) -> Scalar {
Scalar(-self.0)
}

// Overriding comparison operators, currently only supporting == and !=
fn __richcmp__(&self, other: PyRef<Scalar>, op: CompareOp) -> Py<PyAny> {
let py = other.py();
match op {
CompareOp::Eq => (self.0 == other.0).into_py(py),
CompareOp::Ne => (self.0 != other.0).into_py(py),
_ => py.NotImplemented(),
}
}
}

Expand Down
19 changes: 5 additions & 14 deletions tests/test_ristretto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,14 @@
from curve25519_dalek.scalar import Scalar
from curve25519_dalek.constants import RISTRETTO_BASEPOINT_POINT

def test_scalarmult_ristrettopoint_works_both_ways():
def test_scalarmult():
P = RISTRETTO_BASEPOINT_POINT
s = Scalar.from_u64(999)
P1 = P * s
P2 = s * P

assert P1.compress().as_bytes() == P2.compress().as_bytes()

def test_impl_sum():
BASE = RISTRETTO_BASEPOINT_POINT
s1 = Scalar.from_u64(999)
P1 = BASE * s1

s2 = Scalar.from_u64(333);
P2 = BASE * s2;

vec = [P1, P2]
assert P1.compress().as_bytes() == bytes([
70, 101, 76, 35, 4, 33, 130, 159, 62, 231, 63, 205, 135, 227, 60, 26,
147, 227, 5, 110, 18, 24, 124, 104, 111, 26, 24, 111, 8, 181, 54, 33
])

def test_elligator_vs_ristretto_sage():
data = bytes([184, 249, 135, 49, 253, 123, 89, 113, 67, 160, 6, 239,
Expand Down
33 changes: 33 additions & 0 deletions tests/test_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from curve25519_dalek.ristretto import RistrettoPoint, CompressedRistretto
from curve25519_dalek.scalar import Scalar

zero = Scalar.from_u64(0)
zero2 = Scalar.from_u64(0)
one = Scalar.from_u64(1)
two = Scalar.from_u64(2)
five = Scalar.from_u64(5)
ten = Scalar.from_u64(10)

def test_scalar_eq():
assert zero == zero2
assert zero2 == zero
assert zero != one
assert ten != one

def test_sub():
assert one - one == zero
assert one - zero == one
assert ten - five == five

def test_add():
assert one + one == two
assert two + two + one == five
assert zero + zero == zero
assert one + zero == one

def test_scalar_mul():
assert one * one == one
assert one * zero == zero
assert zero * ten == zero
assert one * ten == ten
assert five * two == ten

0 comments on commit 57361ba

Please sign in to comment.