-
Notifications
You must be signed in to change notification settings - Fork 79
/
openssl_aes.py
147 lines (124 loc) · 4.86 KB
/
openssl_aes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Author: Trevor Perrin
# See the LICENSE file for legal information regarding use of this file.
"""OpenSSL/M2Crypto AES implementation."""
from .cryptomath import *
from .aes import *
from .python_aes import Python_AES, Python_AES_CTR
if m2cryptoLoaded:

    def check_cipher_support(mode):
        """
        Pick the AES implementation class for ``mode``.

        Check if the cipher is supported by m2crypto before using it,
        and if it is not, fall back to the pure-Python implementation.
        Checking one key size is enough: if one is supported, they are
        all supported.

        :param int mode: 2 (CBC) or 6 (CTR).
        :returns: :class:`OpenSSL_AES` or :class:`Python_AES` for CBC,
            :class:`OpenSSL_CTR` or :class:`Python_AES_CTR` for CTR.
        :raises AssertionError: for any other mode.
        """
        if mode == 2:
            if hasattr(m2, 'aes_192_cbc'):
                return OpenSSL_AES
            return Python_AES
        # explicit raise instead of ``assert`` so the check survives -O
        if mode != 6:
            raise AssertionError()
        if hasattr(m2, 'aes_192_ctr'):
            return OpenSSL_CTR
        return Python_AES_CTR

    def new(key, mode, IV):
        """
        Create an AES cipher object, preferring the M2Crypto backend.

        :param key: AES key; 16, 24 or 32 bytes.
        :param int mode: 2 (CBC) or 6 (CTR).
        :param IV: initialisation vector for CBC, initial counter block
            for CTR.
        :raises NotImplementedError: if ``mode`` is neither 2 nor 6.
        """
        # IV argument name is a part of the interface
        # pylint: disable=invalid-name
        if mode not in (2, 6):
            raise NotImplementedError()
        impl = check_cipher_support(mode)
        return impl(key, mode, IV)

    def _free_context(cipher):
        """
        Release ``cipher._context`` (if any) and reset it to None.

        The attribute is cleared *before* freeing so that a later call
        (e.g. from ``__del__``) can never free the same context twice.
        ``getattr`` is used because ``__del__`` may run on an object whose
        ``__init__`` raised before ``_context`` was assigned.
        """
        context = getattr(cipher, '_context', None)
        cipher._context = None
        if context is not None:
            m2.cipher_ctx_free(context)

    class OpenSSL_AES(AES):
        """
        AES in CBC mode backed by an M2Crypto (OpenSSL) cipher context.

        The context is created lazily by the first :meth:`encrypt` or
        :meth:`decrypt` call, which also fixes the direction of the
        instance; mixing the two on one instance is rejected.
        """

        def __init__(self, key, mode, IV):
            # IV argument/field names are a part of the interface
            # pylint: disable=invalid-name
            AES.__init__(self, key, mode, IV, "openssl")
            self._IV, self._key = IV, key
            # OpenSSL cipher context; None until first use
            self._context = None
            # direction of the context: True/False once known, else None
            self._encrypt = None

        @property
        def IV(self):
            """IV most recently supplied (constructor or setter)."""
            return self._IV

        @IV.setter
        def IV(self, iv):
            """
            Replace the IV, restarting the CBC chain from ``iv``.

            If a context already exists it is re-created in the *same*
            direction it was used in (previously it was always re-created
            for encryption, breaking later ``decrypt`` calls); otherwise
            creation is left to the first encrypt/decrypt call.
            """
            _free_context(self)
            self._IV = iv
            if self._encrypt is not None:
                self._init_context(encrypt=self._encrypt)

        def _init_context(self, encrypt=True):
            """
            Create the OpenSSL context for the current key and IV.

            :param bool encrypt: True for encryption, False for decryption.
            :raises AssertionError: if the key is not 16, 24 or 32 bytes.
            """
            key_len = len(self._key)
            if key_len == 16:
                cipherType = m2.aes_128_cbc()
            elif key_len == 24:
                cipherType = m2.aes_192_cbc()
            elif key_len == 32:
                cipherType = m2.aes_256_cbc()
            else:
                raise AssertionError()
            # never leak a context that may still be attached
            _free_context(self)
            self._context = m2.cipher_ctx_new()
            m2.cipher_init(self._context, cipherType, self._key, self._IV,
                           int(encrypt))
            # disable OpenSSL's own block padding: data is processed as-is
            m2.cipher_set_padding(self._context, 0)
            self._encrypt = encrypt

        def encrypt(self, plaintext):
            """Encrypt ``plaintext``, continuing the CBC chain."""
            if self._context is None:
                self._init_context(encrypt=True)
            else:
                assert self._encrypt, '.encrypt() not allowed after .decrypt()'
            AES.encrypt(self, plaintext)
            ciphertext = m2.cipher_update(self._context, plaintext)
            return bytearray(ciphertext)

        def decrypt(self, ciphertext):
            """Decrypt ``ciphertext``, continuing the CBC chain."""
            if self._context is None:
                self._init_context(encrypt=False)
            else:
                assert not self._encrypt, \
                    '.decrypt() not allowed after .encrypt()'
            AES.decrypt(self, ciphertext)
            plaintext = m2.cipher_update(self._context, ciphertext)
            return bytearray(plaintext)

        def __del__(self):
            _free_context(self)

    class OpenSSL_CTR(AES):
        """
        AES in CTR mode backed by an M2Crypto (OpenSSL) cipher context.

        The context is created when :attr:`counter` is assigned or, if it
        never is, lazily by the first :meth:`encrypt`/:meth:`decrypt` call.
        """

        def __init__(self, key, mode, IV):
            # IV argument/field names are a part of the interface
            # pylint: disable=invalid-name
            AES.__init__(self, key, mode, IV, "openssl")
            self._IV = IV
            self.key = key
            # OpenSSL cipher context; None until first use
            self._context = None
            self._encrypt = None
            if len(key) not in (16, 24, 32):
                raise AssertionError()

        @property
        def counter(self):
            """Counter block most recently supplied (constructor or setter)."""
            return self._IV

        @counter.setter
        def counter(self, ctr):
            """Restart the keystream from counter block ``ctr``."""
            _free_context(self)
            self._IV = ctr
            self._init_context()

        def _init_context(self, encrypt=True):
            """
            Create the OpenSSL context for the current key and counter.

            :param bool encrypt: True for encryption, False for decryption.
            :raises AssertionError: if the key is not 16, 24 or 32 bytes.
            """
            key_len = len(self.key)
            if key_len == 16:
                cipherType = m2.aes_128_ctr()
            elif key_len == 24:
                cipherType = m2.aes_192_ctr()
            elif key_len == 32:
                cipherType = m2.aes_256_ctr()
            else:
                raise AssertionError()
            # never leak a context that may still be attached
            _free_context(self)
            self._context = m2.cipher_ctx_new()
            m2.cipher_init(self._context, cipherType, self.key, self._IV,
                           int(encrypt))
            m2.cipher_set_padding(self._context, 0)
            self._encrypt = encrypt

        def encrypt(self, plaintext):
            """Encrypt ``plaintext`` with the running keystream."""
            # previously a None context was handed straight to OpenSSL
            # when ``counter`` had not been assigned yet
            if self._context is None:
                self._init_context(encrypt=True)
            ciphertext = m2.cipher_update(self._context, plaintext)
            return bytearray(ciphertext)

        def decrypt(self, ciphertext):
            """Decrypt ``ciphertext`` with the running keystream."""
            if self._context is None:
                self._init_context(encrypt=False)
            plaintext = m2.cipher_update(self._context, ciphertext)
            return bytearray(plaintext)

        def __del__(self):
            _free_context(self)