Skip to content

Commit

Permalink
feat(tariscript): adds ToRistrettoPoint op-code (#4749)
Browse files Browse the repository at this point in the history
Description
---
- Adds ToRistrettoPoint opcodes
- Added scalar stack item

Motivation and Context
---
As per RFC-202
Ref tari-project/rfcs#15
Ref #4742

How Has This Been Tested?
---
Additional unit tests + tests updated
  • Loading branch information
sdbondi committed Oct 3, 2022
1 parent 62384f9 commit 8f872a1
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 67 deletions.
102 changes: 55 additions & 47 deletions infrastructure/tari_script/src/op_codes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use tari_utilities::{hex::Hex, ByteArray, ByteArrayError};
use super::ScriptError;

pub type HashValue = [u8; 32];
pub type ScalarValue = [u8; 32];
pub type Message = [u8; MESSAGE_LENGTH];

const PUBLIC_KEY_LENGTH: usize = 32;
Expand Down Expand Up @@ -116,6 +117,7 @@ pub const OP_CHECK_MULTI_SIG_VERIFY: u8 = 0xaf;
pub const OP_HASH_BLAKE256: u8 = 0xb0;
pub const OP_HASH_SHA256: u8 = 0xb1;
pub const OP_HASH_SHA3: u8 = 0xb2;
pub const OP_TO_RISTRETTO_POINT: u8 = 0xb3;

// Opcode constants: Miscellaneous
pub const OP_RETURN: u8 = 0x60;
Expand Down Expand Up @@ -232,6 +234,9 @@ pub enum Opcode {
/// Identical to CheckMultiSig, except that nothing is pushed to the stack if the m signatures are valid, and the
/// operation fails with VERIFY_FAILED if any of the signatures are invalid.
CheckMultiSigVerify(u8, u8, Vec<RistrettoPublicKey>, Box<Message>),
/// Pops the top element which must be a valid 32-byte scalar or hash and calculates the corresponding Ristretto
/// point, and pushes the result to the stack. Fails with EMPTY_STACK if the stack is empty.
ToRistrettoPoint,

// Miscellaneous
/// Always fails with VERIFY_FAILED.
Expand Down Expand Up @@ -350,6 +355,7 @@ impl Opcode {
let (m, n, keys, msg, end) = Opcode::read_multisig_args(bytes)?;
Ok((CheckMultiSigVerify(m, n, keys, msg), &bytes[end..]))
},
OP_TO_RISTRETTO_POINT => Ok((ToRistrettoPoint, &bytes[1..])),
OP_RETURN => Ok((Return, &bytes[1..])),
OP_IF_THEN => Ok((IfThen, &bytes[1..])),
OP_ELSE => Ok((Else, &bytes[1..])),
Expand Down Expand Up @@ -447,17 +453,18 @@ impl Opcode {
CheckMultiSig(m, n, public_keys, msg) => {
array.extend_from_slice(&[OP_CHECK_MULTI_SIG, *m, *n]);
for public_key in public_keys {
array.extend(public_key.to_vec());
array.extend(public_key.as_bytes());
}
array.extend_from_slice(msg.deref());
},
CheckMultiSigVerify(m, n, public_keys, msg) => {
array.extend_from_slice(&[OP_CHECK_MULTI_SIG_VERIFY, *m, *n]);
for public_key in public_keys {
array.extend(public_key.to_vec());
array.extend(public_key.as_bytes());
}
array.extend_from_slice(msg.deref());
},
ToRistrettoPoint => array.push(OP_TO_RISTRETTO_POINT),
Return => array.push(OP_RETURN),
IfThen => array.push(OP_IF_THEN),
Else => array.push(OP_ELSE),
Expand All @@ -473,70 +480,68 @@ impl fmt::Display for Opcode {
#[allow(clippy::enum_glob_use)]
use Opcode::*;
match self {
CheckHeightVerify(height) => fmt.write_str(&format!("CheckHeightVerify({})", *height)),
CheckHeight(height) => fmt.write_str(&format!("CheckHeight({})", *height)),
CompareHeightVerify => fmt.write_str("CompareHeightVerify"),
CompareHeight => fmt.write_str("CompareHeight"),
Nop => fmt.write_str("Nop"),
PushZero => fmt.write_str("PushZero"),
PushOne => fmt.write_str("PushOne"),
PushHash(h) => fmt.write_str(&format!("PushHash({})", (*h).to_hex())),
PushInt(n) => fmt.write_str(&format!("PushInt({})", *n)),
PushPubKey(h) => fmt.write_str(&format!("PushPubKey({})", (*h).to_hex())),
Drop => fmt.write_str("Drop"),
Dup => fmt.write_str("Dup"),
RevRot => fmt.write_str("RevRot"),
GeZero => fmt.write_str("GeZero"),
GtZero => fmt.write_str("GtZero"),
LeZero => fmt.write_str("LeZero"),
LtZero => fmt.write_str("LtZero"),
Add => fmt.write_str("Add"),
Sub => fmt.write_str("Sub"),
Equal => fmt.write_str("Equal"),
EqualVerify => fmt.write_str("EqualVerify"),
Or(n) => fmt.write_str(&format!("Or({})", *n)),
OrVerify(n) => fmt.write_str(&format!("OrVerify({})", *n)),
HashBlake256 => fmt.write_str("HashBlake256"),
HashSha256 => fmt.write_str("HashSha256"),
HashSha3 => fmt.write_str("HashSha3"),
CheckSig(msg) => fmt.write_str(&format!("CheckSig({})", (*msg).to_hex())),
CheckSigVerify(msg) => fmt.write_str(&format!("CheckSigVerify({})", (*msg).to_hex())),
CheckHeightVerify(height) => write!(fmt, "CheckHeightVerify({})", *height),
CheckHeight(height) => write!(fmt, "CheckHeight({})", *height),
CompareHeightVerify => write!(fmt, "CompareHeightVerify"),
CompareHeight => write!(fmt, "CompareHeight"),
Nop => write!(fmt, "Nop"),
PushZero => write!(fmt, "PushZero"),
PushOne => write!(fmt, "PushOne"),
PushHash(h) => write!(fmt, "PushHash({})", (*h).to_hex()),
PushInt(n) => write!(fmt, "PushInt({})", *n),
PushPubKey(h) => write!(fmt, "PushPubKey({})", (*h).to_hex()),
Drop => write!(fmt, "Drop"),
Dup => write!(fmt, "Dup"),
RevRot => write!(fmt, "RevRot"),
GeZero => write!(fmt, "GeZero"),
GtZero => write!(fmt, "GtZero"),
LeZero => write!(fmt, "LeZero"),
LtZero => write!(fmt, "LtZero"),
Add => write!(fmt, "Add"),
Sub => write!(fmt, "Sub"),
Equal => write!(fmt, "Equal"),
EqualVerify => write!(fmt, "EqualVerify"),
Or(n) => write!(fmt, "Or({})", *n),
OrVerify(n) => write!(fmt, "OrVerify({})", *n),
HashBlake256 => write!(fmt, "HashBlake256"),
HashSha256 => write!(fmt, "HashSha256"),
HashSha3 => write!(fmt, "HashSha3"),
CheckSig(msg) => write!(fmt, "CheckSig({})", (*msg).to_hex()),
CheckSigVerify(msg) => write!(fmt, "CheckSigVerify({})", (*msg).to_hex()),
CheckMultiSig(m, n, public_keys, msg) => {
let keys: Vec<String> = public_keys.iter().map(|p| p.to_hex()).collect();
fmt.write_str(&format!(
write!(
fmt,
"CheckMultiSig({}, {}, [{}], {})",
*m,
*n,
keys.join(", "),
(*msg).to_hex()
))
)
},
CheckMultiSigVerify(m, n, public_keys, msg) => {
let keys: Vec<String> = public_keys.iter().map(|p| p.to_hex()).collect();
fmt.write_str(&format!(
write!(
fmt,
"CheckMultiSigVerify({}, {}, [{}], {})",
*m,
*n,
keys.join(", "),
(*msg).to_hex()
))
)
},
Return => fmt.write_str("Return"),
IfThen => fmt.write_str("IfThen"),
Else => fmt.write_str("Else"),
EndIf => fmt.write_str("EndIf"),
ToRistrettoPoint => write!(fmt, "ToRistrettoPoint"),
Return => write!(fmt, "Return"),
IfThen => write!(fmt, "IfThen"),
Else => write!(fmt, "Else"),
EndIf => write!(fmt, "EndIf"),
}
}
}

#[cfg(test)]
mod test {
use crate::{
op_codes::*,
Opcode,
Opcode::{Dup, PushHash, Return},
ScriptError,
};
use crate::{op_codes::*, Opcode, ScriptError};

#[test]
fn empty_script() {
Expand All @@ -552,9 +557,9 @@ mod test {
let script = [0x60u8, 0x71];
let opcodes = Opcode::parse(&script).unwrap();
let code = opcodes.first().unwrap();
assert_eq!(code, &Return);
assert_eq!(code, &Opcode::Return);
let code = opcodes.get(1).unwrap();
assert_eq!(code, &Dup);
assert_eq!(code, &Opcode::Dup);

let err = Opcode::parse(&[0x7a]).unwrap_err();
assert!(matches!(err, ScriptError::InvalidData));
Expand All @@ -563,7 +568,7 @@ mod test {
#[test]
fn push_hash() {
let (code, b) = Opcode::read_next(b"\x7a/thirty-two~character~hash~val./").unwrap();
assert!(matches!(code, PushHash(v) if &*v == b"/thirty-two~character~hash~val./"));
assert!(matches!(code, Opcode::PushHash(v) if &*v == b"/thirty-two~character~hash~val./"));
assert!(b.is_empty());
}

Expand Down Expand Up @@ -794,6 +799,7 @@ mod test {
test_opcode(OP_HASH_SHA3, &Opcode::HashSha3);
test_opcode(OP_HASH_BLAKE256, &Opcode::HashBlake256);
test_opcode(OP_HASH_SHA256, &Opcode::HashSha256);
test_opcode(OP_TO_RISTRETTO_POINT, &Opcode::ToRistrettoPoint);
test_opcode(OP_IF_THEN, &Opcode::IfThen);
test_opcode(OP_ELSE, &Opcode::Else);
test_opcode(OP_END_IF, &Opcode::EndIf);
Expand Down Expand Up @@ -825,6 +831,7 @@ mod test {
test_opcode(OP_HASH_SHA3, &Opcode::HashSha3);
test_opcode(OP_HASH_BLAKE256, &Opcode::HashBlake256);
test_opcode(OP_HASH_SHA256, &Opcode::HashSha256);
test_opcode(OP_TO_RISTRETTO_POINT, &Opcode::ToRistrettoPoint);
test_opcode(OP_IF_THEN, &Opcode::IfThen);
test_opcode(OP_ELSE, &Opcode::Else);
test_opcode(OP_END_IF, &Opcode::EndIf);
Expand Down Expand Up @@ -856,6 +863,7 @@ mod test {
test_opcode(&Opcode::HashSha3, "HashSha3");
test_opcode(&Opcode::HashBlake256, "HashBlake256");
test_opcode(&Opcode::HashSha256, "HashSha256");
test_opcode(&Opcode::ToRistrettoPoint, "ToRistrettoPoint");
test_opcode(&Opcode::IfThen, "IfThen");
test_opcode(&Opcode::Else, "Else");
test_opcode(&Opcode::EndIf, "EndIf");
Expand Down
43 changes: 43 additions & 0 deletions infrastructure/tari_script/src/script.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use sha2::Sha256;
use sha3::Sha3_256;
use tari_crypto::{
hash::blake2::Blake256,
keys::PublicKey,
ristretto::{RistrettoPublicKey, RistrettoSchnorr, RistrettoSecretKey},
};
use tari_utilities::{
Expand Down Expand Up @@ -260,6 +261,7 @@ impl TariScript {
Err(ScriptError::VerifyFailed)
}
},
ToRistrettoPoint => self.handle_to_ristretto_point(stack),
Return => Err(ScriptError::Return),
IfThen => TariScript::handle_if_then(stack, state),
Else => TariScript::handle_else(state),
Expand Down Expand Up @@ -537,6 +539,19 @@ impl TariScript {

Ok(sig_set.len() == m)
}

fn handle_to_ristretto_point(&self, stack: &mut ExecutionStack) -> Result<(), ScriptError> {
let item = stack.pop().ok_or(ScriptError::StackUnderflow)?;
let scalar = match &item {
StackItem::Hash(hash) => hash.as_slice(),
StackItem::Scalar(scalar) => scalar.as_slice(),
_ => return Err(ScriptError::IncompatibleTypes),
};
let ristretto_sk = RistrettoSecretKey::from_bytes(scalar).map_err(|_| ScriptError::InvalidData)?;
let ristretto_pk = RistrettoPublicKey::from_secret_key(&ristretto_sk);
stack.push(StackItem::PublicKey(ristretto_pk))?;
Ok(())
}
}

impl Hex for TariScript {
Expand Down Expand Up @@ -1540,4 +1555,32 @@ mod test {
let result = script.execute(&inputs).unwrap_err();
assert_eq!(result, ScriptError::Return);
}

#[test]
fn to_ristretto_point() {
use crate::StackItem::PublicKey;
let mut rng = rand::thread_rng();
let (k_1, p_1) = RistrettoPublicKey::random_keypair(&mut rng);

use crate::Opcode::ToRistrettoPoint;
let ops = vec![ToRistrettoPoint];
let script = TariScript::new(ops);

// Invalid stack type
let inputs = inputs!(RistrettoPublicKey::default());
let err = script.execute(&inputs).unwrap_err();
assert!(matches!(err, ScriptError::IncompatibleTypes));

// scalar
let mut scalar = [0u8; 32];
scalar.copy_from_slice(k_1.as_bytes());
let inputs = inputs!(scalar);
let result = script.execute(&inputs).unwrap();
assert_eq!(result, PublicKey(p_1.clone()));

// hash
let inputs = ExecutionStack::new(vec![Hash(scalar)]);
let result = script.execute(&inputs).unwrap();
assert_eq!(result, PublicKey(p_1));
}
}
Loading

0 comments on commit 8f872a1

Please sign in to comment.