From de4b4d18b8d12c3e006ff536358df7ddd10a7858 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radek=20Hol=C3=BD?= Date: Mon, 13 May 2024 13:08:09 +0200 Subject: [PATCH] fix captured_amount not being saved when processing data Partially fixes: https://github.com/jazzband/django-payments/issues/309 --- HISTORY.rst | 4 + payments_payu/provider.py | 3 + tests/test_payu.py | 276 +++++++++++++++++++++++++++++++++----- 3 files changed, 246 insertions(+), 37 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index f406589..81e668e 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -3,6 +3,10 @@ History ------- +Unreleased +********** +* fix captured_amount not being saved when processing data + 1.4.0 (2024-04-12) ****************** * fix backward compatibility by making PayuProvider's get_refund_description argument optional diff --git a/payments_payu/provider.py b/payments_payu/provider.py index d2a3f9b..648f05e 100644 --- a/payments_payu/provider.py +++ b/payments_payu/provider.py @@ -589,6 +589,9 @@ def process_notification(self, payment, request): data["order"]["totalAmount"], data["order"]["currencyCode"], ) + payment.objects.filter(pk=payment.pk).update( + captured_amount=payment.captured_amount + ) payment.change_status(status) return HttpResponse("ok", status=200) return HttpResponse("not ok", status=500) diff --git a/tests/test_payu.py b/tests/test_payu.py index d291b6c..bb6ecfd 100644 --- a/tests/test_payu.py +++ b/tests/test_payu.py @@ -3,6 +3,7 @@ import contextlib import json import warnings +from copy import deepcopy from decimal import Decimal from unittest import TestCase @@ -33,9 +34,60 @@ def __eq__(self, other): return self.json == json.loads(other) +class PaymentQuerySet(Mock): + __payments = {} + + def create(self, **kwargs): + if kwargs: + raise NotImplementedError(f"arguments not supported yet: {kwargs}") + id_ = max(self.__payments) + 1 if self.__payments else 1 + self.__payments[id_] = {} + payment = Payment() + payment.id = id_ + payment.save() + return payment + + def get(self, *args, **kwargs): + if args or kwargs: + return self.filter(*args, **kwargs).get() + payment = Payment() + (payment_fields,) = self.__payments.values() + for payment_field_name, payment_field_value in payment_fields.items(): + setattr(payment, payment_field_name, deepcopy(payment_field_value)) + return payment + + def filter(self, *args, pk=None, **kwargs): + if args or kwargs: + raise NotImplementedError(f"arguments not supported yet: {args}, {kwargs}") + if pk is not None: + return PaymentQuerySet( + {pk_: payment for pk_, payment in self.__payments.items() if pk_ == pk} + ) + return self + + def update(self, **kwargs): + for payment in self.__payments.values(): + for field_name, field_value in kwargs.items(): + if not any( + field.name == field_name + for field in Payment._meta.get_fields( + include_parents=True, include_hidden=True + ) + ): + raise NotImplementedError( + f"updating unknown field not supported yet: {field_name}" + ) + payment[field_name] = deepcopy(field_value) + + def delete(self): + self.__payments.clear() + + class Payment(Mock): UNSET = object() + objects = PaymentQuerySet() + id = 1 description = "payment" currency = "USD" @@ -64,13 +116,20 @@ class Payment(Mock): } ) - def change_fraud_status(self, status, message=""): + @property + def pk(self): + return self.id + + def change_fraud_status(self, status, message="", commit=True): self.fraud_status = status self.message = message + if commit: + self.save() def change_status(self, status, message=""): self.status = status self.message = message + self.save(update_fields=["status", "message"]) def get_failure_url(self): return "http://cancel.com" @@ -110,12 +169,65 @@ def set_renew_token( self.automatic_renewal = automatic_renewal self.renewal_triggered_by = renewal_triggered_by + def save(self, *args, update_fields=None, **kwargs): + if args or kwargs: + raise NotImplementedError(f"arguments not supported yet: {args}, {kwargs}") + if update_fields is None: + update_fields = { + field.name + for field in self._meta.get_fields( + include_parents=True, include_hidden=True + ) + } + Payment.objects.filter(pk=self.pk).update( + **{field: getattr(self, field) for field in update_fields} + ) + + def refresh_from_db(self, *args, **kwargs): + if args or kwargs: + raise NotImplementedError(f"arguments not supported yet: {args}, {kwargs}") + payment_from_db = Payment.objects.get(pk=self.pk) + for field in self._meta.get_fields(include_parents=True, include_hidden=True): + field_value_from_db = getattr(payment_from_db, field.name) + setattr(self, field.name, field_value_from_db) + + class Meta(Mock): + def get_fields(self, include_parents=True, include_hidden=False): + fields = [] + for field_name in { + "id", + "description", + "currency", + "delivery", + "status", + "fraud_status", + "tax", + "total", + "billing_first_name", + "billing_last_name", + "billing_email", + "captured_amount", + "variant", + "transaction_id", + "message", + "customer_ip_address", + "token", + "extra_data", + }: + field = Mock() + field.name = field_name + fields.append(field) + return tuple(fields) + + _meta = Meta() + class TestPayuProvider(TestCase): urls = "myapp.test_urls" def setUp(self): - self.payment = Payment() + Payment.objects.delete() + self.payment = Payment.objects.create() def set_up_provider(self, recurring, express, **kwargs): with patch("requests.post") as mocked_post: @@ -629,6 +741,9 @@ def test_process_notification(self): self.assertEqual(ret_val.content, b"ok") self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) self.assertEqual(self.payment.captured_amount, Decimal("0")) + self.payment.refresh_from_db() + self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) + self.assertEqual(self.payment.captured_amount, Decimal("0")) def test_process_notification_cancelled(self): """Test processing PayU cancelled notification""" @@ -663,6 +778,9 @@ def test_process_notification_cancelled(self): self.assertEqual(ret_val.content, b"ok") self.assertEqual(self.payment.status, PaymentStatus.REJECTED) self.assertEqual(self.payment.captured_amount, Decimal("0")) + self.payment.refresh_from_db() + self.assertEqual(self.payment.status, PaymentStatus.REJECTED) + self.assertEqual(self.payment.captured_amount, Decimal("0")) def test_process_notification_refund(self): """Test processing PayU refund notification""" @@ -699,6 +817,10 @@ def test_process_notification_refund(self): self.assertEqual(self.payment.status, PaymentStatus.REFUNDED) self.assertEqual(self.payment.total, Decimal(220)) self.assertEqual(self.payment.captured_amount, Decimal(220)) + self.payment.refresh_from_db() + self.assertEqual(self.payment.status, PaymentStatus.REFUNDED) + self.assertEqual(self.payment.total, Decimal(220)) + self.assertEqual(self.payment.captured_amount, Decimal(220)) def test_process_notification_partial_refund(self): """Test processing PayU partial refund notification""" @@ -734,6 +856,9 @@ def test_process_notification_partial_refund(self): self.assertEqual(ret_val.__class__.__name__, "HttpResponse") self.assertEqual(ret_val.status_code, 200) self.assertEqual(ret_val.content, b"ok") + self.assertEqual(self.payment.total, Decimal(220)) + self.assertEqual(self.payment.captured_amount, Decimal("110")) + self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) self.payment.refresh_from_db() self.assertEqual(self.payment.total, Decimal(220)) self.assertEqual(self.payment.captured_amount, Decimal("110")) @@ -792,6 +917,9 @@ def test_process_notification_total_amount(self): self.assertEqual(ret_val.content, b"ok") self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) self.assertEqual(self.payment.captured_amount, Decimal("2")) + self.payment.refresh_from_db() + self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) + self.assertEqual(self.payment.captured_amount, Decimal("2")) def test_process_notification_error(self): """Test processing PayU notification with wrong signature""" @@ -812,6 +940,9 @@ def test_process_notification_error(self): self.assertEqual(ret_val.content, b"not ok") self.assertEqual(self.payment.status, PaymentStatus.WAITING) self.assertEqual(self.payment.captured_amount, Decimal("0")) + self.payment.refresh_from_db() + self.assertEqual(self.payment.status, PaymentStatus.WAITING) + self.assertEqual(self.payment.captured_amount, Decimal("0")) def test_process_notification_error_malformed_post(self): """Test processing PayU notification with malformed POST""" @@ -885,6 +1016,9 @@ def test_process_first_renew(self): ) self.assertEqual(self.payment.status, PaymentStatus.WAITING) self.assertEqual(self.payment.captured_amount, Decimal("0")) + self.payment.refresh_from_db() + self.assertEqual(self.payment.status, PaymentStatus.WAITING) + self.assertEqual(self.payment.captured_amount, Decimal("0")) def test_process_renew(self): """Test processing renew""" @@ -949,6 +1083,9 @@ def test_process_renew(self): ) self.assertEqual(self.payment.status, PaymentStatus.WAITING) self.assertEqual(self.payment.captured_amount, Decimal("0")) + self.payment.refresh_from_db() + self.assertEqual(self.payment.status, PaymentStatus.WAITING) + self.assertEqual(self.payment.captured_amount, Decimal("0")) def test_process_renew_card_on_file(self): """Test processing renew""" @@ -1014,6 +1151,9 @@ def test_process_renew_card_on_file(self): ) self.assertEqual(self.payment.status, PaymentStatus.WAITING) self.assertEqual(self.payment.captured_amount, Decimal("0")) + self.payment.refresh_from_db() + self.assertEqual(self.payment.status, PaymentStatus.WAITING) + self.assertEqual(self.payment.captured_amount, Decimal("0")) def test_auto_complete_recurring(self): """Test processing renew. The function should return 'success' string, if nothing is required from user.""" @@ -1029,6 +1169,9 @@ def test_auto_complete_recurring(self): self.assertEqual(redirect, "success") self.assertEqual(self.payment.status, PaymentStatus.WAITING) self.assertEqual(self.payment.captured_amount, Decimal("0")) + self.payment.refresh_from_db() + self.assertEqual(self.payment.status, PaymentStatus.WAITING) + self.assertEqual(self.payment.captured_amount, Decimal("0")) def test_auto_complete_recurring_cvv2(self): """Test processing renew when cvv2 form is required - it should return the payment processing URL""" @@ -1050,6 +1193,9 @@ def test_auto_complete_recurring_cvv2(self): self.assertEqual(redirect, "https://example.com/payment/token") self.assertEqual(self.payment.status, PaymentStatus.WAITING) self.assertEqual(self.payment.captured_amount, Decimal("0")) + self.payment.refresh_from_db() + self.assertEqual(self.payment.status, PaymentStatus.WAITING) + self.assertEqual(self.payment.captured_amount, Decimal("0")) def test_delete_card_token(self): """Test delete_card_token()""" @@ -1116,6 +1262,8 @@ def test_reject_order(self): }, ) self.assertEqual(self.payment.status, PaymentStatus.REJECTED) + self.payment.refresh_from_db() + self.assertEqual(self.payment.status, PaymentStatus.REJECTED) def test_reject_order_error(self): """Test processing renew""" @@ -1138,6 +1286,8 @@ def test_reject_order_error(self): }, ) self.assertEqual(self.payment.status, PaymentStatus.WAITING) + self.payment.refresh_from_db() + self.assertEqual(self.payment.status, PaymentStatus.WAITING) def test_refund(self): with warnings.catch_warnings(record=True) as caught_warnings: @@ -1203,16 +1353,21 @@ def test_refund(self): with refund_request_patch as refund_request_mock: amount = self.provider.refund(self.payment, Decimal(110)) - payment_extra_data_refund_responses = json.loads(self.payment.extra_data)[ - "refund_responses" - ] self.assertEqual(refund_request_mock.call_count, 1) self.assertEqual(amount, Decimal(110)) self.assertEqual(self.payment.total, Decimal(220)) self.assertEqual(self.payment.captured_amount, Decimal(210)) self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) self.assertEqual( - payment_extra_data_refund_responses, + json.loads(self.payment.extra_data)["refund_responses"], + [payment_extra_data_refund_response_previous, refund_request_response_body], + ) + self.payment.refresh_from_db() + self.assertEqual(self.payment.total, Decimal(220)) + self.assertEqual(self.payment.captured_amount, Decimal(210)) + self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) + self.assertEqual( + json.loads(self.payment.extra_data)["refund_responses"], [payment_extra_data_refund_response_previous, refund_request_response_body], ) self.assertFalse(caught_warnings) @@ -1261,16 +1416,22 @@ def test_refund_no_amount(self): with refund_request_patch as refund_request_mock: amount = self.provider.refund(self.payment) - payment_extra_data_refund_responses = json.loads(self.payment.extra_data)[ - "refund_responses" - ] self.assertEqual(refund_request_mock.call_count, 1) self.assertEqual(amount, Decimal(220)) self.assertEqual(self.payment.total, Decimal(220)) self.assertEqual(self.payment.captured_amount, Decimal(220)) self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) self.assertEqual( - payment_extra_data_refund_responses, [refund_request_response_body] + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body], + ) + self.payment.refresh_from_db() + self.assertEqual(self.payment.total, Decimal(220)) + self.assertEqual(self.payment.captured_amount, Decimal(220)) + self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) + self.assertEqual( + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body], ) self.assertFalse(caught_warnings) @@ -1320,16 +1481,22 @@ def test_refund_no_get_refund_ext_id(self): ): amount = self.provider.refund(self.payment, Decimal(110)) - payment_extra_data_refund_responses = json.loads(self.payment.extra_data)[ - "refund_responses" - ] self.assertEqual(refund_request_mock.call_count, 1) self.assertEqual(amount, Decimal(110)) self.assertEqual(self.payment.total, Decimal(220)) self.assertEqual(self.payment.captured_amount, Decimal(220)) self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) self.assertEqual( - payment_extra_data_refund_responses, [refund_request_response_body] + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body], + ) + self.payment.refresh_from_db() + self.assertEqual(self.payment.total, Decimal(220)) + self.assertEqual(self.payment.captured_amount, Decimal(220)) + self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) + self.assertEqual( + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body], ) self.assertFalse(caught_warnings) @@ -1376,16 +1543,22 @@ def test_refund_no_ext_id(self): with refund_request_patch as refund_request_mock: amount = self.provider.refund(self.payment, Decimal(110)) - payment_extra_data_refund_responses = json.loads(self.payment.extra_data)[ - "refund_responses" - ] self.assertEqual(refund_request_mock.call_count, 1) self.assertEqual(amount, Decimal(110)) self.assertEqual(self.payment.total, Decimal(220)) self.assertEqual(self.payment.captured_amount, Decimal(220)) self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) self.assertEqual( - payment_extra_data_refund_responses, [refund_request_response_body] + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body], + ) + self.payment.refresh_from_db() + self.assertEqual(self.payment.total, Decimal(220)) + self.assertEqual(self.payment.captured_amount, Decimal(220)) + self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) + self.assertEqual( + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body], ) self.assertFalse(caught_warnings) @@ -1433,9 +1606,6 @@ def test_refund_no_ext_id_twice(self): amount1 = self.provider.refund(self.payment, Decimal(200)) amount2 = self.provider.refund(self.payment, Decimal(200)) - payment_extra_data_refund_responses = json.loads(self.payment.extra_data)[ - "refund_responses" - ] self.assertEqual(refund_request_mock.call_count, 2) self.assertEqual(amount2, amount1) self.assertEqual(amount2, Decimal(200)) @@ -1443,7 +1613,15 @@ def test_refund_no_ext_id_twice(self): self.assertEqual(self.payment.captured_amount, Decimal(220)) self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) self.assertEqual( - payment_extra_data_refund_responses, + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body, refund_request_response_body], + ) + self.payment.refresh_from_db() + self.assertEqual(self.payment.total, Decimal(220)) + self.assertEqual(self.payment.captured_amount, Decimal(220)) + self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) + self.assertEqual( + json.loads(self.payment.extra_data)["refund_responses"], [refund_request_response_body, refund_request_response_body], ) self.assertFalse(caught_warnings) @@ -1492,16 +1670,22 @@ def test_refund_finalized(self): with refund_request_patch as refund_request_mock: amount = self.provider.refund(self.payment, Decimal(110)) - payment_extra_data_refund_responses = json.loads(self.payment.extra_data)[ - "refund_responses" - ] self.assertEqual(refund_request_mock.call_count, 1) self.assertEqual(amount, Decimal(110)) self.assertEqual(self.payment.total, Decimal(220)) self.assertEqual(self.payment.captured_amount, Decimal(220)) self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) self.assertEqual( - payment_extra_data_refund_responses, [refund_request_response_body] + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body], + ) + self.payment.refresh_from_db() + self.assertEqual(self.payment.total, Decimal(220)) + self.assertEqual(self.payment.captured_amount, Decimal(220)) + self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) + self.assertEqual( + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body], ) self.assertFalse(caught_warnings) @@ -1552,15 +1736,21 @@ def test_refund_canceled(self): with refund_request_patch as refund_request_mock: self.provider.refund(self.payment, Decimal(110)) - payment_extra_data_refund_responses = json.loads(self.payment.extra_data)[ - "refund_responses" - ] self.assertEqual(refund_request_mock.call_count, 1) self.assertEqual(self.payment.total, Decimal(220)) self.assertEqual(self.payment.captured_amount, Decimal(220)) self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) self.assertEqual( - payment_extra_data_refund_responses, [refund_request_response_body] + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body], + ) + self.payment.refresh_from_db() + self.assertEqual(self.payment.total, Decimal(220)) + self.assertEqual(self.payment.captured_amount, Decimal(220)) + self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) + self.assertEqual( + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body], ) self.assertFalse(caught_warnings) @@ -1607,15 +1797,21 @@ def test_refund_error(self): with refund_request_patch as refund_request_mock: self.provider.refund(self.payment, Decimal(110)) - payment_extra_data_refund_responses = json.loads(self.payment.extra_data)[ - "refund_responses" - ] self.assertEqual(refund_request_mock.call_count, 1) self.assertEqual(self.payment.total, Decimal(220)) self.assertEqual(self.payment.captured_amount, Decimal(220)) self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) self.assertEqual( - payment_extra_data_refund_responses, [refund_request_response_body] + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body], + ) + self.payment.refresh_from_db() + self.assertEqual(self.payment.total, Decimal(220)) + self.assertEqual(self.payment.captured_amount, Decimal(220)) + self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) + self.assertEqual( + json.loads(self.payment.extra_data)["refund_responses"], + [refund_request_response_body], ) self.assertFalse(caught_warnings) @@ -1635,13 +1831,19 @@ def test_refund_no_get_refund_description(self): with self.assertRaisesRegex(ValueError, r"^get_refund_description not set"): self.provider.refund(self.payment, Decimal(110)) - payment_extra_data_refund_responses = json.loads(self.payment.extra_data).get( - "refund_responses", [] + self.assertEqual(self.payment.total, Decimal(220)) + self.assertEqual(self.payment.captured_amount, Decimal(220)) + self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) + self.assertFalse( + json.loads(self.payment.extra_data).get("refund_responses", []) ) + self.payment.refresh_from_db() self.assertEqual(self.payment.total, Decimal(220)) self.assertEqual(self.payment.captured_amount, Decimal(220)) self.assertEqual(self.payment.status, PaymentStatus.CONFIRMED) - self.assertFalse(payment_extra_data_refund_responses) + self.assertFalse( + json.loads(self.payment.extra_data).get("refund_responses", []) + ) self.assertEqual(len(caught_warnings), 1) self.assertTrue(issubclass(caught_warnings[0].category, DeprecationWarning)) self.assertEqual(