From a8a8f940f0140fb622ca0388ddc489794b0dcef2 Mon Sep 17 00:00:00 2001 From: adamantike Date: Tue, 29 Mar 2016 13:21:58 -0300 Subject: [PATCH] Simplified bitsize calculation --- rsa/common.py | 18 ++++++++---------- tests/test_common.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/rsa/common.py b/rsa/common.py index 34142cc..bbd15e6 100644 --- a/rsa/common.py +++ b/rsa/common.py @@ -16,6 +16,8 @@ """Common functionality shared by several modules.""" +from rsa._compat import is_integer + class NotRelativePrimeError(ValueError): def __init__(self, a, b, d, msg=None): @@ -50,21 +52,17 @@ def bit_size(num): :returns: Returns the number of bits in the integer. """ + # Make sure this is an int and not a float. + if not is_integer(num): + raise TypeError + if num == 0: return 0 if num < 0: num = -num - # Make sure this is an int and not a float. - num & 1 - - hex_num = "%x" % num - return ((len(hex_num) - 1) * 4) + { - '0': 0, '1': 1, '2': 2, '3': 2, - '4': 3, '5': 3, '6': 3, '7': 3, - '8': 4, '9': 4, 'a': 4, 'b': 4, - 'c': 4, 'd': 4, 'e': 4, 'f': 4, - }[hex_num[0]] + binary_num = "{0:b}".format(num) + return len(binary_num) def _bit_size(number): diff --git a/tests/test_common.py b/tests/test_common.py index ef32f61..dc0a154 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -55,6 +55,7 @@ def test_bad_type(self): self.assertRaises(TypeError, byte_size, dict()) self.assertRaises(TypeError, byte_size, "") self.assertRaises(TypeError, byte_size, None) + self.assertRaises(TypeError, byte_size, 0.0) class TestBitSize(unittest.TestCase): @@ -76,6 +77,16 @@ def test_values(self): self.assertEqual(_bit_size((1 << 1024) + 1), 1025) self.assertEqual(_bit_size((1 << 1024) - 1), 1024) + def test_negative_values(self): + self.assertEqual(bit_size(-1023), 10) + self.assertEqual(bit_size(-1024), 11) + self.assertEqual(bit_size(-1025), 11) + self.assertEqual(bit_size(-1 << 1024), 1025) + self.assertEqual(bit_size(-((1 << 1024) + 1)), 1025) + self.assertEqual(bit_size(-((1 << 1024) - 1)), 1024) + + self.assertRaises(ValueError, _bit_size, -1024) + class TestInverse(unittest.TestCase): def test_normal(self):