diff --git a/Level-1/hack.py b/Level-1/hack.py index 731ddc0..a9c7dcb 100644 --- a/Level-1/hack.py +++ b/Level-1/hack.py @@ -11,5 +11,13 @@ def test_4(self): order_4 = c.Order(id='4', items=[payment, tv, reimbursement]) self.assertEqual(c.validorder(order_4), 'Order ID: 4 - Payment imbalance: $-1000.00') + # Valid payments that should add up correctly, but do not + def test_5(self): + small_item = c.Item(type='product', description='accessory', amount=3.3, quantity=1) + payment_1 = c.Item(type='payment', description='invoice_5_1', amount=1.1, quantity=1) + payment_2 = c.Item(type='payment', description='invoice_5_2', amount=2.2, quantity=1) + order_5 = c.Order(id='5', items=[small_item, payment_1, payment_2]) + self.assertEqual(c.validorder(order_5), 'Order ID: 5 - Full payment received!') + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/Level-1/solution.py b/Level-1/solution.py index 82f16f9..0f85bd5 100644 --- a/Level-1/solution.py +++ b/Level-1/solution.py @@ -1,4 +1,5 @@ from collections import namedtuple +from decimal import * Order = namedtuple('Order', 'id, items') Item = namedtuple('Item', 'type, description, amount, quantity') @@ -8,16 +9,16 @@ MAX_TOTAL = 1e6 # maximum total amount accepted for an order def validorder(order): - net = 0 + net = Decimal('0') for item in order.items: if item.type == 'payment': # sets a reasonable min & max value for the invoice amounts if item.amount > -1*MAX_ITEM_AMOUNT and item.amount < MAX_ITEM_AMOUNT: - net += item.amount + net += Decimal(str(item.amount)) elif item.type == 'product': if item.quantity > 0 and item.quantity <= MAX_QUANTITY and item.amount > 0 and item.amount <= MAX_ITEM_AMOUNT: - net -= item.amount * item.quantity + net -= Decimal(str(item.amount)) * item.quantity if net > MAX_TOTAL or net < -1*MAX_TOTAL: return("Total amount exceeded") else: @@ -41,4 +42,20 @@ def validorder(order): We also need to protect from a scenario where the attacker sends a huge number of items, resulting in a huge net. We can do this by limiting all variables to reasonable values. + +In addition, using floating-point data types for calculations involving financial values causes unexpected rounding and comparison +errors as it cannot represent decimal numbers with the precision we expect. +For example, running `0.1 + 0.2` in the Python interpreter gives `0.30000000000000004` instead of 0.3. + +The solution to this is to use the Decimal type for calculations that should work in the same way "as the arithmetic that people learn at school." +-- except from Python's documentation on Decimal (https://docs.python.org/3/library/decimal.html). + +It is also necessary to convert the floating point values to string first before passing it to the Decimal constructor. +If the floating point value is passed to the Decimal constructor, the rounded value is stored instead. + +Compare the following examples from the interpreter: +>>> Decimal(0.3) +Decimal('0.299999999999999988897769753748434595763683319091796875') +>>> Decimal('0.3') +Decimal('0.3') ''' diff --git a/Level-1/tests.py b/Level-1/tests.py index 126e3d6..0e057a6 100644 --- a/Level-1/tests.py +++ b/Level-1/tests.py @@ -25,4 +25,4 @@ def test_3(self): self.assertEqual(c.validorder(order_3), 'Order ID: 3 - Payment imbalance: $-1000.00') if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()