Skip to content

Commit

Permalink
sail: still working on vector AES. See #24,#26
Browse files Browse the repository at this point in the history
  • Loading branch information
ben-marshall committed Aug 18, 2020
1 parent a95e775 commit 115114f
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 52 deletions.
19 changes: 19 additions & 0 deletions sail/riscv_crypto_tests.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

/*
* A collection of very small unit tests to check that internal
* functions used by multiple crypto extension instructions give the correct
* results given known inputs.
*/


val crypto_test_vaes_128_keystep_fwd : unit -> bool
function crypto_test_vaes_128_keystep_fwd () = {
let input : bits(128) = 0x09cf4f3cabf7158828aed2a62b7e1516;
let grm_out1: bits(128) = 0x2A6C760523A3393988542CB1A0FAFE17;
let dut_out1: bits(128) = vaes128_keystep_fwd(input, 0x0);
if(dut_out1 != grm_out1) then false else {
let grm_out2: bits(128) = 0x7359f67f5935807a7a96b943f2c295f2;
let dut_out2: bits(128) = vaes128_keystep_fwd(dut_out1, 0x1);
if(dut_out2 != grm_out2) then false else true
}
}
145 changes: 104 additions & 41 deletions sail/riscv_insts_crypto_rvv_aes.sail
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,22 @@ function vaes_round_inv (last, state , rkey ) = {


/*
*
* Given the current 128-bit round key for AES-128, compute the next
* 128-bit round key.
*/
val vaes128_keystep_fwd : (bits(128), bits(4)) -> bits(128)
function vaes128_keystep_fwd (current , rnd ) = {
current /* TODO */
let wi0 : bits(32) = current[ 31.. 0]; /* Unpack current rnd key words */
let wi1 : bits(32) = current[ 63..32];
let wi2 : bits(32) = current[ 95..64];
let wi3 : bits(32) = current[127..96];
let rcon: bits(32) = aes_decode_rcon(rnd);
let temp: bits(32) = aes_subword(ror32(wi3,to_bits(8,24))) ^ rcon;
let wo0 : bits(32) = wi0 ^ temp;
let wo1 : bits(32) = wi1 ^ wo0;
let wo2 : bits(32) = wi2 ^ wo1;
let wo3 : bits(32) = wi3 ^ wo2;
wo3 @ wo2 @ wo1 @ wo0 /* Return value */
}


Expand All @@ -142,24 +153,48 @@ function vaes128_keystep_fwd (current , rnd ) = {
*/
val vaes192_keystep_fwd : (bits(192), bits(4)) -> bits(128)
function vaes192_keystep_fwd (current , rnd ) = {
current[127..0] /* TODO */
current[127..0] /* TODO: vaes192_keystep_fwd */
}


/*
*
* Given the previous two 128-bit round keys, compute the next
* 128-bit round key.
* TODO: vaes256_keystep_fwd is broken
*/
val vaes256_keystep_fwd : (bits(256), bits(4)) -> bits(128)
function vaes256_keystep_fwd (current , rnd ) = {
current[127..0] /* TODO */
val vaes256_keystep_fwd : (bits(128), bits(128), bits(4)) -> bits(128)
function vaes256_keystep_fwd (prev , current , rnd ) = {
let wi0 : bits(32) = prev [ 31.. 0]; /* Unpack current rnd key words */
let wi1 : bits(32) = prev [ 63.. 32];
let wi2 : bits(32) = prev [ 95.. 64];
let wi3 : bits(32) = prev [127.. 96];
let wi4 : bits(32) = current[ 31.. 0];
let wi5 : bits(32) = current[ 63.. 32];
let wi6 : bits(32) = current[ 95.. 64];
let wi7 : bits(32) = current[127.. 96];
let rcon: bits(32) = aes_decode_rcon(rnd);
let temp: bits(32) = aes_subword(ror32(wi7,to_bits(8,24))) ^ rcon;
let wo0 : bits(32) = wi0 ^ temp;
let wo1 : bits(32) = wi1 ^ wo0;
let wo2 : bits(32) = wi2 ^ wo1;
let wo3 : bits(32) = wi3 ^ wo2;
if(bit_to_bool(rnd[0])) then {
wo3 @ wo2 @ wo1 @ wo0 /* Return round key N */
} else {
let wo4 : bits(32) = wi4 ^ aes_subword(wo3);
let wo5 : bits(32) = wi5 ^ wo4;
let wo6 : bits(32) = wi6 ^ wo5;
let wo7 : bits(32) = wi7 ^ wo6;
wo7 @ wo6 @ wo5 @ wo4 /* Return round key N+1 */
}
}

/*
*
*/
val vaes128_keystep_inv : (bits(128), bits(4)) -> bits(128)
function vaes128_keystep_inv (current , rnd ) = {
current /* TODO */
current /* TODO: vaes128_keystep_inv */
}


Expand All @@ -168,7 +203,7 @@ function vaes128_keystep_inv (current , rnd ) = {
*/
val vaes192_keystep_inv : (bits(192), bits(4)) -> bits(128)
function vaes192_keystep_inv (current , rnd ) = {
current[127..0] /* TODO */
current[127..0] /* TODO: vaes192_keystep_inv */
}


Expand All @@ -177,7 +212,7 @@ function vaes192_keystep_inv (current , rnd ) = {
*/
val vaes256_keystep_inv : (bits(256), bits(4)) -> bits(128)
function vaes256_keystep_inv (current , rnd ) = {
current[127..0] /* TODO */
current[127..0] /* TODO: vaes256_keystep_inv */
}


Expand Down Expand Up @@ -370,17 +405,16 @@ function clause execute ( VAES192INVKEYI (vt,rnd,vs2)) = {

/* AES 256 Single round key schedule - forwards */
function clause execute ( VAES256KEYI (vt,rnd,vs2)) = {
if(vaes_asserts()) then {
foreach (i from 0 to vGetVL()) {
let prev : bits(128) = vGetElement128(vt , i);
let current : bits(128) = vGetElement128(vs2, i);
let keys : bits(256) = current @ prev;
let result : bits(128) = vaes256_keystep_fwd(keys, rnd);
vSetElement128(vt, i, result);
};
RETIRE_SUCCESS
} else
RETIRE_FAIL
if(vaes_asserts()) then {
foreach (i from 0 to vGetVL()) {
let prev : bits(128) = vGetElement128(vt , i);
let current : bits(128) = vGetElement128(vs2, i);
let next : bits(128) = vaes256_keystep_fwd(prev, current, rnd);
vSetElement128(vt, i, next);
};
RETIRE_SUCCESS
} else
RETIRE_FAIL
}

/* AES 256 Single round key schedule - inverse */
Expand Down Expand Up @@ -408,7 +442,7 @@ function clause execute ( VAES256INVKEYI (vt,rnd,vs2)) = {
/* AES 128 All Rounds Encrypt - Vector-Scalar */
function clause execute ( VAESE128_VS (vt,vs1)) = {
if(vaes_asserts()) then {
foreach (i from 1 to vGetVL()) {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
rkey : bits(128) = vGetElement128(vs1, 0); /* Vector-Scalar */
state : bits(128) = vGetElement128(vt , i);
Expand All @@ -429,7 +463,7 @@ function clause execute ( VAESE128_VS (vt,vs1)) = {
/* AES 192 All Rounds Encrypt - Vector-Scalar */
function clause execute ( VAESE192_VS (vt,vs1)) = {
if(vaes_asserts()) then {
foreach (i from 1 to vGetVL()) {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
ekey : bits(192) = vGetElement256(vs1, 0)[191..0];
state : bits(128) = vGetElement128(vt , i);
Expand All @@ -450,7 +484,7 @@ function clause execute ( VAESE192_VS (vt,vs1)) = {
/* AES 256 All Rounds Encrypt - Vector-Scalar */
function clause execute ( VAESE256_VS (vt,vs1)) = {
if(vaes_asserts()) then {
foreach (i from 1 to vGetVL()) {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
ekey : bits(256) = vGetElement256(vs1, 0);
state : bits(128) = vGetElement128(vt , i);
Expand All @@ -476,7 +510,7 @@ function clause execute ( VAESE256_VS (vt,vs1)) = {
/* AES 128 All Rounds Decrypt - Vector-Scalar */
function clause execute ( VAESD128_VS (vt,vs1)) = {
if(vaes_asserts()) then {
foreach (i from 1 to vGetVL()) {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
rkey : bits(128) = vGetElement128(vs1, 0); /* Vector-Scalar */
state : bits(128) = vGetElement128(vt , i);
Expand All @@ -497,7 +531,7 @@ function clause execute ( VAESD128_VS (vt,vs1)) = {
/* AES 192 All Rounds Decrypt - Vector-Scalar */
function clause execute ( VAESD192_VS (vt,vs1)) = {
if(vaes_asserts()) then {
foreach (i from 1 to vGetVL()) {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
ekey : bits(192) = vGetElement256(vs1, 0)[191..0];
state : bits(128) = vGetElement128(vt , i);
Expand All @@ -517,7 +551,7 @@ function clause execute ( VAESD192_VS (vt,vs1)) = {
/* AES 256 All Rounds Decrypt - Vector-Scalar */
function clause execute ( VAESD256_VS (vt,vs1)) = {
if(vaes_asserts()) then {
foreach (i from 1 to vGetVL()) {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
ekey : bits(256) = vGetElement256(vs1, 0);
state : bits(128) = vGetElement128(vt , i);
Expand All @@ -542,7 +576,7 @@ function clause execute ( VAESD256_VS (vt,vs1)) = {
/* AES 128 All Rounds Encrypt - Vector-Scalar */
function clause execute ( VAESE128_VV (vt,vs1)) = {
if(vaes_asserts()) then {
foreach (i from 1 to vGetVL()) {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
rkey : bits(128) = vGetElement128(vs1, i); /* Vector-Vector */
state : bits(128) = vGetElement128(vt , i);
Expand All @@ -563,7 +597,7 @@ function clause execute ( VAESE128_VV (vt,vs1)) = {
/* AES 192 All Rounds Encrypt - Vector-Scalar */
function clause execute ( VAESE192_VV (vt,vs1)) = {
if(vaes_asserts()) then {
foreach (i from 1 to vGetVL()) {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
ekey : bits(192) = vGetElement256(vs1, i)[191..0];
state : bits(128) = vGetElement128(vt , i);
Expand All @@ -584,7 +618,7 @@ function clause execute ( VAESE192_VV (vt,vs1)) = {
/* AES 256 All Rounds Encrypt - Vector-Scalar */
function clause execute ( VAESE256_VV (vt,vs1)) = {
if(vaes_asserts()) then {
foreach (i from 1 to vGetVL()) {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
ekey : bits(256) = vGetElement256(vs1, i);
state : bits(128) = vGetElement128(vt , i);
Expand All @@ -606,12 +640,12 @@ function clause execute ( VAESE256_VV (vt,vs1)) = {
* ------------------------------------------------------------
*/

/* AES 128 All Rounds Decrypt - Vector-Scalar */
/* AES 128 All Rounds Decrypt - Vector-Vector */
function clause execute ( VAESD128_VV (vt,vs1)) = {
if(vaes_asserts()) then {
foreach (i from 1 to vGetVL()) {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
rkey : bits(128) = vGetElement128(vs1, i); /* Vector-Scalar */
rkey : bits(128) = vGetElement128(vs1, i); /* Vector-Vector */
state : bits(128) = vGetElement128(vt , i);
foreach(i from 1 to 10) {
state = vaes_round_inv(false, state, rkey);
Expand All @@ -627,10 +661,10 @@ function clause execute ( VAESD128_VV (vt,vs1)) = {
}


/* AES 192 All Rounds Decrypt - Vector-Scalar */
/* AES 192 All Rounds Decrypt - Vector-Vector */
function clause execute ( VAESD192_VV (vt,vs1)) = {
if(vaes_asserts()) then {
foreach (i from 1 to vGetVL()) {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
ekey : bits(192) = vGetElement256(vs1, i)[191..0];
state : bits(128) = vGetElement128(vt , i);
Expand All @@ -647,10 +681,10 @@ function clause execute ( VAESD192_VV (vt,vs1)) = {
RETIRE_FAIL
}

/* AES 256 All Rounds Decrypt - Vector-Scalar */
/* AES 256 All Rounds Decrypt - Vector-Vector */
function clause execute ( VAESD256_VV (vt,vs1)) = {
if(vaes_asserts()) then {
foreach (i from 1 to vGetVL()) {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
ekey : bits(256) = vGetElement256(vs1, i);
state : bits(128) = vGetElement128(vt , i);
Expand All @@ -667,12 +701,27 @@ function clause execute ( VAESD256_VV (vt,vs1)) = {
RETIRE_FAIL
}

/*
* Vector-Vector Get final round key from first round key.
* ------------------------------------------------------------
*/


/* AES 128 Get final decryption key from encryption key. */
function clause execute ( VAES128RKEY (vt)) = {
/* TBD, implemented as nop.*/
RETIRE_SUCCESS
if(vaes_asserts()) then {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
rkey : bits(128) = vGetElement128(vt, i);
foreach(i from 1 to 10) {
rkey = vaes128_keystep_fwd(rkey, rnd);
rnd = rnd + 1;
};
vSetElement128(vt, i, rkey);
};
RETIRE_SUCCESS
} else
RETIRE_FAIL
}


Expand All @@ -685,6 +734,20 @@ function clause execute ( VAES192RKEY (vt)) = {

/* AES 256 Get final decryption key from encryption key. */
function clause execute ( VAES256RKEY (vt)) = {
/* TBD, implemented as nop.*/
RETIRE_SUCCESS
if(vaes_asserts()) then {
foreach (i from 0 to vGetVL()) {
rnd : bits(4) = 0x0;
rkey : bits(256) = vGetElement256(vt, i);
prev : bits(128) = rkey[127.. 0];
curr : bits(128) = rkey[255..128];
foreach(i from 1 to 7) { /* Two steps at a time */
prev = vaes256_keystep_fwd(prev, curr, rnd+0);
curr = vaes256_keystep_fwd(curr, prev, rnd+1);
rnd = rnd + 2;
};
vSetElement256(vt, i, curr @ prev);
};
RETIRE_SUCCESS
} else
RETIRE_FAIL
}
39 changes: 28 additions & 11 deletions sail/riscv_types_crypto.sail
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,13 @@ function vGetElement256 (ridx, elem) = {
/* Dummy function to set the i'th 128-bit element of a vector register */
val vSetElement128 : (vregidx, int, bits(128)) -> unit effect {wreg, escape}
function vSetElement128 (ridx, elem, value) = {
assert(true,"TODO")
assert(true,"TODO: vSetElement128")
}

/* Dummy function to set the i'th 256-bit element of a vector register */
val vSetElement256 : (vregidx, int, bits(256)) -> unit effect {wreg, escape}
function vSetElement256 (ridx, elem, value) = {
assert(true,"TODO: vSetElement256")
}


Expand Down Expand Up @@ -240,16 +246,16 @@ function aes_mixcolumn_inv (x) = {
val aes_decode_rcon : bits(4) -> bits(32)
function aes_decode_rcon (r) = {
match r {
0x0 => 0x00000001,
0x1 => 0x00000002,
0x2 => 0x00000004,
0x3 => 0x00000008,
0x4 => 0x00000010,
0x5 => 0x00000020,
0x6 => 0x00000040,
0x7 => 0x00000080,
0x8 => 0x0000001b,
0x9 => 0x00000036,
0x0 => 0x01000000,
0x1 => 0x02000000,
0x2 => 0x04000000,
0x3 => 0x08000000,
0x4 => 0x10000000,
0x5 => 0x20000000,
0x6 => 0x40000000,
0x7 => 0x80000000,
0x8 => 0x1b000000,
0x9 => 0x36000000,
0xA => 0x00000000,
0xB => 0x00000000,
0xC => 0x00000000,
Expand Down Expand Up @@ -357,6 +363,17 @@ function aes_sbox_inv (x) = {
sbox_lookup(x, aes_sbox_inv_table)
}

/* AES SubWord function used in the key expansion
* - Applies the forward sbox to each byte in the input word.
*/
val aes_subword : bits(32) -> bits(32)
function aes_subword (x) = {
aes_sbox_fwd(x[31..24]) @
aes_sbox_fwd(x[23..16]) @
aes_sbox_fwd(x[15.. 8]) @
aes_sbox_fwd(x[ 7.. 0])
}


/* Easy function to perform an SM4 SBox operation on 1 byte. */
val sm4_sbox : bits(8) -> bits(8)
Expand Down

0 comments on commit 115114f

Please sign in to comment.