11use crate :: function:: OptionalArg ;
2+ use crate :: obj:: objbyteinner:: PyBytesLike ;
3+ use crate :: obj:: objstr:: PyStringRef ;
24use crate :: obj:: { objiter, objtype} ;
3- use crate :: pyobject:: { PyObjectRef , PyResult , TypeProtocol } ;
5+ use crate :: pyobject:: { Either , PyObjectRef , PyResult , TypeProtocol } ;
46use crate :: VirtualMachine ;
7+ use volatile:: Volatile ;
58
69fn operator_length_hint ( obj : PyObjectRef , default : OptionalArg , vm : & VirtualMachine ) -> PyResult {
710 let default = default. unwrap_or_else ( || vm. new_int ( 0 ) ) ;
@@ -17,8 +20,82 @@ fn operator_length_hint(obj: PyObjectRef, default: OptionalArg, vm: &VirtualMach
1720 Ok ( hint)
1821}
1922
23+ #[ inline( never) ]
24+ #[ cold]
25+ fn timing_safe_cmp ( a : & [ u8 ] , b : & [ u8 ] ) -> bool {
26+ // we use raw pointers here to keep faithful to the C implementation and
27+ // to try to avoid any optimizations rustc might do with slices
28+ let len_a = a. len ( ) ;
29+ let a = a. as_ptr ( ) ;
30+ let len_b = b. len ( ) ;
31+ let b = b. as_ptr ( ) ;
32+ /* The volatile type declarations make sure that the compiler has no
33+ * chance to optimize and fold the code in any way that may change
34+ * the timing.
35+ */
36+ let length: Volatile < usize > ;
37+ let mut left: Volatile < * const u8 > ;
38+ let mut right: Volatile < * const u8 > ;
39+ let mut result: u8 = 0 ;
40+
41+ /* loop count depends on length of b */
42+ length = Volatile :: new ( len_b) ;
43+ left = Volatile :: new ( std:: ptr:: null ( ) ) ;
44+ right = Volatile :: new ( b) ;
45+
46+ /* don't use else here to keep the amount of CPU instructions constant,
47+ * volatile forces re-evaluation
48+ * */
49+ if len_a == length. read ( ) {
50+ left. write ( Volatile :: new ( a) . read ( ) ) ;
51+ result = 0 ;
52+ }
53+ if len_a != length. read ( ) {
54+ left. write ( b) ;
55+ result = 1 ;
56+ }
57+
58+ for _ in 0 ..length. read ( ) {
59+ let l = left. read ( ) ;
60+ left. write ( l. wrapping_add ( 1 ) ) ;
61+ let r = right. read ( ) ;
62+ right. write ( r. wrapping_add ( 1 ) ) ;
63+ // safety: the 0..length range will always be either:
64+ // * as long as the length of both a and b, if len_a and len_b are equal
65+ // * as long as b, and both `left` and `right` are b
66+ result |= unsafe { l. read_volatile ( ) ^ r. read_volatile ( ) } ;
67+ }
68+
69+ result == 0
70+ }
71+
72+ fn operator_compare_digest (
73+ a : Either < PyStringRef , PyBytesLike > ,
74+ b : Either < PyStringRef , PyBytesLike > ,
75+ vm : & VirtualMachine ,
76+ ) -> PyResult < bool > {
77+ let res = match ( a, b) {
78+ ( Either :: A ( a) , Either :: A ( b) ) => {
79+ if !a. as_str ( ) . is_ascii ( ) || !b. as_str ( ) . is_ascii ( ) {
80+ return Err ( vm. new_type_error (
81+ "comparing strings with non-ASCII characters is not supported" . to_string ( ) ,
82+ ) ) ;
83+ }
84+ timing_safe_cmp ( a. as_str ( ) . as_bytes ( ) , b. as_str ( ) . as_bytes ( ) )
85+ }
86+ ( Either :: B ( a) , Either :: B ( b) ) => a. with_ref ( |a| b. with_ref ( |b| timing_safe_cmp ( a, b) ) ) ,
87+ _ => {
88+ return Err ( vm. new_type_error (
89+ "unsupported operand types(s) or combination of types" . to_string ( ) ,
90+ ) )
91+ }
92+ } ;
93+ Ok ( res)
94+ }
95+
2096pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
2197 py_module ! ( vm, "_operator" , {
2298 "length_hint" => vm. ctx. new_function( operator_length_hint) ,
99+ "_compare_digest" => vm. ctx. new_function( operator_compare_digest) ,
23100 } )
24101}
0 commit comments