From 234d909609ea128f77df28ca419b1abccc789384 Mon Sep 17 00:00:00 2001 From: Piotr Zabieglik <55899043+zedzior@users.noreply.github.com> Date: Tue, 7 May 2024 08:59:14 +0200 Subject: [PATCH 1/5] Optimise variant.stocks query. (#15897) --- .../product_variant_stocks_create.py | 8 +- .../product_variant_stocks_delete.py | 8 +- .../product_variant_stocks_update.py | 8 +- .../queries/test_product_variant_query.py | 36 +++ saleor/graphql/product/types/products.py | 13 +- saleor/graphql/shipping/dataloaders.py | 22 +- saleor/graphql/warehouse/dataloaders.py | 258 +++++++++++++----- 7 files changed, 264 insertions(+), 89 deletions(-) diff --git a/saleor/graphql/product/bulk_mutations/product_variant_stocks_create.py b/saleor/graphql/product/bulk_mutations/product_variant_stocks_create.py index bb0e8857cf1..828548940aa 100644 --- a/saleor/graphql/product/bulk_mutations/product_variant_stocks_create.py +++ b/saleor/graphql/product/bulk_mutations/product_variant_stocks_create.py @@ -15,9 +15,7 @@ from ...core.mutations import BaseMutation from ...core.types import BulkStockError, NonNullList from ...plugins.dataloaders import get_plugin_manager_promise -from ...warehouse.dataloaders import ( - StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader, -) +from ...warehouse.dataloaders import StocksByProductVariantIdLoader from ...warehouse.types import Warehouse from ..mutations.product.product_create import StockInput from ..types import ProductVariant @@ -70,9 +68,7 @@ def perform_mutation(cls, _root, info: ResolveInfo, /, **data): manager.product_variant_back_in_stock, stock, webhooks=webhooks ) - StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( - info.context - ).clear((variant.id, None, None)) + StocksByProductVariantIdLoader(info.context).clear(variant.id) variant = ChannelContext(node=variant, channel_slug=None) return cls(product_variant=variant) diff --git a/saleor/graphql/product/bulk_mutations/product_variant_stocks_delete.py b/saleor/graphql/product/bulk_mutations/product_variant_stocks_delete.py index 93ab24e7567..9a14165fc33 100644 --- a/saleor/graphql/product/bulk_mutations/product_variant_stocks_delete.py +++ b/saleor/graphql/product/bulk_mutations/product_variant_stocks_delete.py @@ -14,9 +14,7 @@ from ...core.types import NonNullList, StockError from ...core.validators import validate_one_of_args_is_in_mutation from ...plugins.dataloaders import get_plugin_manager_promise -from ...warehouse.dataloaders import ( - StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader, -) +from ...warehouse.dataloaders import StocksByProductVariantIdLoader from ...warehouse.types import Warehouse from ..types import ProductVariant @@ -85,9 +83,7 @@ def perform_mutation(cls, _root, info: ResolveInfo, /, **data): stocks_to_delete.delete() - StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( - info.context - ).clear((variant.id, None, None)) + StocksByProductVariantIdLoader(info.context).clear(variant.id) variant = ChannelContext(node=variant, channel_slug=None) return cls(product_variant=variant) diff --git a/saleor/graphql/product/bulk_mutations/product_variant_stocks_update.py b/saleor/graphql/product/bulk_mutations/product_variant_stocks_update.py index b4d00d92752..bc6a562146e 100644 --- a/saleor/graphql/product/bulk_mutations/product_variant_stocks_update.py +++ b/saleor/graphql/product/bulk_mutations/product_variant_stocks_update.py @@ -16,9 +16,7 @@ from ...core.types import BulkStockError, NonNullList from ...core.validators import validate_one_of_args_is_in_mutation from ...plugins.dataloaders import get_plugin_manager_promise -from ...warehouse.dataloaders import ( - StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader, -) +from ...warehouse.dataloaders import StocksByProductVariantIdLoader from ...warehouse.types import Warehouse from ..mutations.product.product_create import StockInput from ..types import ProductVariant @@ -82,9 +80,7 @@ def perform_mutation(cls, _root, info: ResolveInfo, /, **data): manager = get_plugin_manager_promise(info.context).get() cls.update_or_create_variant_stocks(variant, stocks, warehouses, manager) - StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( - info.context - ).clear((variant.id, None, None)) + StocksByProductVariantIdLoader(info.context).clear(variant.id) variant = ChannelContext(node=variant, channel_slug=None) return cls(product_variant=variant) diff --git a/saleor/graphql/product/tests/queries/test_product_variant_query.py b/saleor/graphql/product/tests/queries/test_product_variant_query.py index c9da985ba37..fd77cb3cc72 100644 --- a/saleor/graphql/product/tests/queries/test_product_variant_query.py +++ b/saleor/graphql/product/tests/queries/test_product_variant_query.py @@ -3,6 +3,7 @@ from measurement.measures import Weight from .....core.units import WeightUnits +from .....warehouse import WarehouseClickAndCollectOption from ....core.enums import WeightUnitsEnum from ....tests.utils import assert_no_permission, get_graphql_content @@ -167,6 +168,41 @@ def test_fetch_variant_no_stocks( ) +def test_fetch_variant_stocks_from_click_and_collect_warehouse( + staff_api_client, + product, + permission_manage_products, + channel_USD, +): + # given + query = QUERY_VARIANT + variant = product.variants.first() + stocks_count = variant.stocks.count() + warehouse = variant.stocks.first().warehouse + + # remove the warehouse shipping zones and mark it as click and collect + # the stocks for this warehouse should be still returned + warehouse.shipping_zones.clear() + warehouse.click_and_collect_option = WarehouseClickAndCollectOption.LOCAL_STOCK + warehouse.save(update_fields=["click_and_collect_option"]) + + variant_id = graphene.Node.to_global_id("ProductVariant", variant.pk) + variables = {"id": variant_id, "countryCode": "EU", "channel": channel_USD.slug} + staff_api_client.user.user_permissions.add(permission_manage_products) + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["productVariant"] + assert data["name"] == variant.name + assert data["created"] == variant.created_at.isoformat() + + assert len(data["stocksByAddress"]) == stocks_count + assert not data["deprecatedStocksByCountry"] + + QUERY_PRODUCT_VARIANT_CHANNEL_LISTING = """ query ProductVariantDetails($id: ID!, $channel: String) { productVariant(id: $id, channel: $channel) { diff --git a/saleor/graphql/product/types/products.py b/saleor/graphql/product/types/products.py index 2194de2cec5..f3f8b2da55b 100644 --- a/saleor/graphql/product/types/products.py +++ b/saleor/graphql/product/types/products.py @@ -117,6 +117,7 @@ from ...warehouse.dataloaders import ( AvailableQuantityByProductVariantIdCountryCodeAndChannelSlugLoader, PreorderQuantityReservedByVariantChannelListingIdLoader, + StocksByProductVariantIdLoader, StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader, ) from ...warehouse.types import Stock @@ -434,9 +435,15 @@ def resolve_stocks( ): if address is not None: country_code = address.country - return StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( - info.context - ).load((root.node.id, country_code, root.channel_slug)) + channle_slug = root.channel_slug + if channle_slug or country_code: + return StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( # noqa: E501 + info.context + ).load( + (root.node.id, country_code, root.channel_slug) + ) + else: + return StocksByProductVariantIdLoader(info.context).load(root.node.id) @staticmethod @load_site_callback diff --git a/saleor/graphql/shipping/dataloaders.py b/saleor/graphql/shipping/dataloaders.py index 71edf0a4bd0..e90893174db 100644 --- a/saleor/graphql/shipping/dataloaders.py +++ b/saleor/graphql/shipping/dataloaders.py @@ -1,6 +1,6 @@ from collections import defaultdict -from django.db.models import Exists, F, OuterRef +from django.db.models import Exists, F, OuterRef, Q from ...channel.models import Channel from ...shipping.models import ( @@ -218,3 +218,23 @@ def map_channels(channels): .load_many({pk for pk, _ in channel_and_zone_is_pairs}) .then(map_channels) ) + + +class ShippingZonesByCountryLoader(DataLoader): + context_key = "shippingzones_by_country" + + def batch_load(self, keys): + lookup = Q() + for key in keys: + lookup |= Q(countries__contains=key) + shipping_zones = ShippingZone.objects.using( + self.database_connection_name + ).filter(lookup) + + shipping_zones_by_country = defaultdict(list) + for shipping_zone in shipping_zones: + for country_code in keys: + if country_code in shipping_zone.countries: + shipping_zones_by_country[country_code].append(shipping_zone) + + return [shipping_zones_by_country[key] for key in keys] diff --git a/saleor/graphql/warehouse/dataloaders.py b/saleor/graphql/warehouse/dataloaders.py index 9180e3701d0..000898051c8 100644 --- a/saleor/graphql/warehouse/dataloaders.py +++ b/saleor/graphql/warehouse/dataloaders.py @@ -19,6 +19,7 @@ from django.db.models.functions import Coalesce from django.utils import timezone from django_stubs_ext import WithAnnotations +from promise import Promise from ...channel.models import Channel from ...product.models import ProductVariantChannelListing @@ -32,7 +33,12 @@ Warehouse, ) from ...warehouse.reservations import is_reservation_enabled +from ..channel.dataloaders import ChannelBySlugLoader from ..core.dataloaders import DataLoader +from ..shipping.dataloaders import ( + ShippingZonesByChannelIdLoader, + ShippingZonesByCountryLoader, +) from ..site.dataloaders import get_site_promise if TYPE_CHECKING: @@ -388,79 +394,183 @@ class StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( context_key = "stocks_with_available_quantity_by_productvariant_country_and_channel" def batch_load(self, keys): - # Split the list of keys by country first. A typical query will only touch - # a handful of unique countries but may access thousands of product variants - # so it's cheaper to execute one query per country. - variants_by_country_and_channel: DefaultDict[ - Tuple[CountryCode, str], List[int] - ] = defaultdict(list) - for variant_id, country_code, channel_slug in keys: - variants_by_country_and_channel[(country_code, channel_slug)].append( - variant_id - ) - - # For each country code execute a single query for all product variants. - stocks_by_variant_and_country: DefaultDict[ - VariantIdCountryCodeChannelSlug, List[Stock] - ] = defaultdict(list) - for key, variant_ids in variants_by_country_and_channel.items(): - country_code, channel_slug = key - variant_ids_stocks = self.batch_load_stocks_by_country( - country_code, channel_slug, variant_ids - ) - for variant_id, stocks in variant_ids_stocks: - stocks_by_variant_and_country[ - (variant_id, country_code, channel_slug) - ].extend(stocks) + def with_channels(channels): + def with_shipping_zones(data): + def with_warehouses(warehouse_data): + warehouses_by_channel, warehouses_by_zone = warehouse_data + + # build maps + variant_ids_by_country_and_channel_map: defaultdict[ + tuple[CountryCode, str], list[int] + ] = defaultdict(list) + for variant_id, country_code, channel_slug in keys: + variant_ids_by_country_and_channel_map[ + (country_code, channel_slug) + ].append(variant_id) + + shipping_zones_by_channel_map = { + channel.slug: set(shipping_zones) + for shipping_zones, channel in zip( + shipping_zones_by_channel, channels + ) + } + shipping_zones_by_country_map = { + country_code: set(shipping_zones) + for shipping_zones, country_code in zip( + shipping_zones_by_country, country_codes + ) + } + warehouses_by_channel_map = { + channel.slug: set(warehouses) + for warehouses, channel in zip(warehouses_by_channel, channels) + } + warehouses_by_zone_map = { + shipping_zone_id: set(warehouses) + for warehouses, shipping_zone_id in zip( + warehouses_by_zone, shipping_zone_ids + ) + } + + # filter warehouses + warehouse_ids_by_country_and_channel_map = ( + self.get_relevant_warehouses( + variant_ids_by_country_and_channel_map, + shipping_zones_by_channel_map, + shipping_zones_by_country_map, + warehouses_by_channel_map, + warehouses_by_zone_map, + ) + ) - return [stocks_by_variant_and_country[key] for key in keys] + variant_ids = list(set(key[0] for key in keys)) + warehouse_ids = { + warehouse_id + for warehouse_ids in warehouse_ids_by_country_and_channel_map.values() # noqa: E501 + for warehouse_id in warehouse_ids + } + stocks_qs = Stock.objects.using( + self.database_connection_name + ).filter( + product_variant_id__in=variant_ids, + warehouse_id__in=warehouse_ids, + ) - def batch_load_stocks_by_country( - self, - country_code: Optional[CountryCode], - channel_slug: Optional[str], - variant_ids: Iterable[int], - ) -> Iterable[Tuple[int, List[Stock]]]: - # convert to set to not return the same stocks for the same variant twice - variant_ids_set = set(variant_ids) - stocks = ( - Stock.objects.all() - .using(self.database_connection_name) - .filter(product_variant_id__in=variant_ids_set) - ) - if country_code: - stocks = stocks.filter( - warehouse__shipping_zones__countries__contains=country_code - ) - if channel_slug: - # click and collect warehouses don't have to be assigned to the shipping - # zones, the others must - stocks = stocks.filter( - Q( - warehouse__shipping_zones__channels__slug=channel_slug, - warehouse__channels__slug=channel_slug, + stocks_qs = stocks_qs.annotate_available_quantity().order_by("pk") + + results = [] + for variant_id, country_code, channel_slug in keys: + warehouse_ids = warehouse_ids_by_country_and_channel_map[ + (country_code, channel_slug) + ] + stocks = [ + stock + for stock in stocks_qs + if stock.product_variant_id == variant_id + and stock.warehouse_id in warehouse_ids + ] + results.append(stocks) + + return results + + shipping_zones_by_channel, shipping_zones_by_country = data + channel_ids = [channel.id for channel in channels] + + shipping_zone_by_channel_ids = { + shipping_zone.id + for shipping_zones in shipping_zones_by_channel + for shipping_zone in shipping_zones + } + shipping_zone_by_country_ids = { + shipping_zone.id + for shipping_zones in shipping_zones_by_country + for shipping_zone in shipping_zones + } + shipping_zone_ids = ( + shipping_zone_by_channel_ids | shipping_zone_by_country_ids ) - | Q( - warehouse__channels__slug=channel_slug, - warehouse__click_and_collect_option__in=[ - WarehouseClickAndCollectOption.LOCAL_STOCK, - WarehouseClickAndCollectOption.ALL_WAREHOUSES, - ], + + warehouses_by_channel = WarehousesByChannelIdLoader( + self.context + ).load_many(channel_ids) + warehouses_by_zone = WarehousesByShippingZoneIdLoader( + self.context + ).load_many(shipping_zone_ids) + return Promise.all([warehouses_by_channel, warehouses_by_zone]).then( + with_warehouses ) - ) - stocks = stocks.annotate_available_quantity() - stocks_by_variant_id_map: DefaultDict[int, List[Stock]] = defaultdict(list) - for stock in stocks: - stocks_by_variant_id_map[stock.product_variant_id].append(stock) + channel_ids = [channel.id for channel in set(channels)] + shipping_zones_by_channel = ShippingZonesByChannelIdLoader( + self.context + ).load_many(channel_ids) - return [ - ( - variant_id, - stocks_by_variant_id_map[variant_id], - ) - for variant_id in variant_ids_set - ] + country_codes = list(set(key[1] for key in keys if key[1])) + shipping_zones_by_country = ShippingZonesByCountryLoader( + self.context + ).load_many(country_codes) + + return Promise.all( + [shipping_zones_by_channel, shipping_zones_by_country] + ).then(with_shipping_zones) + + channel_slugs = list(set(key[2] for key in keys if key[2])) + return ( + ChannelBySlugLoader(self.context) + .load_many(channel_slugs) + .then(with_channels) + ) + + @staticmethod + def get_relevant_warehouses( + variant_ids_by_country_and_channel_map, + shipping_zones_by_channel_map, + shipping_zones_by_country_map, + warehouses_by_channel_map, + warehouses_by_zone_map, + ): + warehouse_ids_by_country_and_channel_map = defaultdict(list) + for ( + country_code, + channel_slug, + ), variant_ids in variant_ids_by_country_and_channel_map.items(): + warehouses = set() + warehouses_in_country = set() + # get warehouses from shipping zones in specific country + if country_code: + shipping_zones_in_country = shipping_zones_by_country_map[country_code] + for zone in shipping_zones_in_country: + warehouses_in_country |= warehouses_by_zone_map[zone.id] + + if channel_slug: + warehouses_in_channel = warehouses_by_channel_map[channel_slug] + shipping_zones_in_channel = shipping_zones_by_channel_map[channel_slug] + cc_options = [ + WarehouseClickAndCollectOption.LOCAL_STOCK, + WarehouseClickAndCollectOption.ALL_WAREHOUSES, + ] + # get click & collect warehouses available in channel + cc_warehouses_in_channel = { + warehouse + for warehouse in warehouses_in_channel + if warehouse.click_and_collect_option in cc_options + } + + # get warehouses with shipping zone, both available in channel + warehouses_with_zone_in_channel = set() + for zone in shipping_zones_in_channel: + warehouses_with_zone_in_channel |= ( + warehouses_by_zone_map[zone.id] & warehouses_in_channel + ) + + warehouses = cc_warehouses_in_channel | warehouses_with_zone_in_channel + if country_code: + warehouses &= warehouses_in_country + + warehouse_ids_by_country_and_channel_map[(country_code, channel_slug)] = [ + warehouse.id for warehouse in warehouses + ] + + return warehouse_ids_by_country_and_channel_map class StocksReservationsByCheckoutTokenLoader(DataLoader): @@ -623,3 +733,17 @@ def map_warehouses(warehouses): .load_many({pk for pk, _ in warehouse_and_shipping_zone_in_pairs}) .then(map_warehouses) ) + + +class StocksByProductVariantIdLoader(DataLoader): + context_key = "stocks_by_product_variant" + + def batch_load(self, keys): + stocks = Stock.objects.using(self.database_connection_name).filter( + product_variant_id__in=keys + ) + stocks_by_variant_id = defaultdict(list) + for stock in stocks: + stocks_by_variant_id[stock.product_variant_id].append(stock) + + return [stocks_by_variant_id[key] for key in keys] From 7e25fffa7fbeb9f2c61da4258922da8d8b177fd0 Mon Sep 17 00:00:00 2001 From: zedzior Date: Tue, 7 May 2024 09:10:26 +0200 Subject: [PATCH 2/5] Release 3.16.46 --- package-lock.json | 4 ++-- package.json | 2 +- pyproject.toml | 2 +- saleor/__init__.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/package-lock.json b/package-lock.json index ddf7da3020a..4ca4966b025 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "saleor", - "version": "3.16.45", + "version": "3.16.46", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "saleor", - "version": "3.16.45", + "version": "3.16.46", "license": "BSD-3-Clause", "devDependencies": { "@release-it/bumper": "^4.0.0", diff --git a/package.json b/package.json index 84e3a9e0b1f..c848299dd65 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "saleor", - "version": "3.16.45", + "version": "3.16.46", "engines": { "node": ">=16 <17", "npm": ">=7" diff --git a/pyproject.toml b/pyproject.toml index ef4d9eddcf3..2f2b6336f6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "saleor" -version = "3.16.45" +version = "3.16.46" description = "A modular, high performance, headless e-commerce platform built with Python, GraphQL, Django, and React." authors = [ "Saleor Commerce " ] license = "BSD-3-Clause" diff --git a/saleor/__init__.py b/saleor/__init__.py index 0c736f08d02..a074d8340a4 100644 --- a/saleor/__init__.py +++ b/saleor/__init__.py @@ -3,7 +3,7 @@ from .celeryconf import app as celery_app __all__ = ["celery_app"] -__version__ = "3.16.45" +__version__ = "3.16.46" class PatchedSubscriberExecutionContext(object): From 0a814ef6c6bcb258e4279a93a9899caa698cb007 Mon Sep 17 00:00:00 2001 From: Iga Karbowiak <40886528+IKarbowiak@users.noreply.github.com> Date: Wed, 8 May 2024 09:40:39 +0200 Subject: [PATCH 3/5] Save Atobarai gateway response in transaction object (#15925) --- saleor/payment/gateways/np_atobarai/api.py | 43 +++--- .../gateways/np_atobarai/api_helpers.py | 7 +- .../payment/gateways/np_atobarai/api_types.py | 13 +- .../gateways/np_atobarai/tests/test_api.py | 131 ++++++++++++------ 4 files changed, 124 insertions(+), 70 deletions(-) diff --git a/saleor/payment/gateways/np_atobarai/api.py b/saleor/payment/gateways/np_atobarai/api.py index 5f06a8fb013..60dd10be187 100644 --- a/saleor/payment/gateways/np_atobarai/api.py +++ b/saleor/payment/gateways/np_atobarai/api.py @@ -44,11 +44,11 @@ def register_transaction( reason for pending is returned as error message. """ action = TRANSACTION_REGISTRATION - result, error_codes = register(config, payment_information) + result, error_codes, raw_response = register(config, payment_information) if error_codes: error_messages = get_error_messages_from_codes(action, error_codes=error_codes) - return errors_payment_result(error_messages) + return errors_payment_result(error_messages, result) status = result["authori_result"] transaction_id = result["np_transaction_id"] @@ -67,6 +67,7 @@ def register_transaction( status=status, psp_reference=transaction_id, errors=error_messages, + raw_response=raw_response, ) @@ -81,13 +82,13 @@ def cancel_transaction( add_action_to_code(action, error_code=NO_PSP_REFERENCE) ) - result, error_codes = cancel(config, psp_reference) + _result, error_codes, raw_response = cancel(config, psp_reference) if error_codes: error_messages = get_error_messages_from_codes(action, error_codes=error_codes) - return errors_payment_result(error_messages) + return errors_payment_result(error_messages, raw_response) - return PaymentResult(status=PaymentStatus.SUCCESS) + return PaymentResult(status=PaymentStatus.SUCCESS, raw_response=raw_response) def change_transaction( @@ -115,7 +116,9 @@ def change_transaction( ] } - result, error_codes = np_request(config, "patch", "/transactions/update", json=data) + result, error_codes, raw_response = np_request( + config, "patch", "/transactions/update", json=data + ) if not error_codes: status = result["authori_result"] @@ -127,23 +130,23 @@ def change_transaction( payment.order, "cancel", transaction_id, cancel_error_codes ) error_messages = result["authori_hold"] - return errors_payment_result(error_messages) + return errors_payment_result(error_messages, raw_response) - return PaymentResult( - status=PaymentStatus.SUCCESS, - ) + return PaymentResult(status=PaymentStatus.SUCCESS, raw_response=raw_response) if PRE_FULFILLMENT_ERROR_CODE in error_codes: logger.info( "Fulfillment for payment with id %s was reported", payment_information.graphql_payment_id, ) - return PaymentResult(status=PaymentStatus.FOR_REREGISTRATION) + return PaymentResult( + status=PaymentStatus.FOR_REREGISTRATION, raw_response=raw_response + ) error_messages = get_error_messages_from_codes( action=TRANSACTION_CHANGE, error_codes=error_codes ) - return errors_payment_result(error_messages) + return errors_payment_result(error_messages, raw_response) def reregister_transaction_for_partial_return( @@ -171,13 +174,14 @@ def reregister_transaction_for_partial_return( ) ) - if cancel_error_codes := cancel(config, psp_reference).error_codes: + result, cancel_error_codes, raw_response = cancel(config, psp_reference) + if cancel_error_codes: error_messages = get_error_messages_from_codes( action=TRANSACTION_CANCELLATION, error_codes=cancel_error_codes ) - return errors_payment_result(error_messages) + return errors_payment_result(error_messages, raw_response) - result, error_codes = register( + result, error_codes, raw_response = register( config, payment_information, format_price(billed_amount, payment_information.currency), @@ -187,7 +191,7 @@ def reregister_transaction_for_partial_return( if not error_codes: new_psp_reference = result["np_transaction_id"] - result, error_codes = report( + result, error_codes, raw_response = report( config, shipping_company_code, new_psp_reference, tracking_number ) @@ -195,16 +199,17 @@ def reregister_transaction_for_partial_return( error_messages = get_error_messages_from_codes( FULFILLMENT_REPORT, error_codes=error_codes ) - return errors_payment_result(error_messages) + return errors_payment_result(error_messages, raw_response) return PaymentResult( status=PaymentStatus.SUCCESS, psp_reference=new_psp_reference, + raw_response=raw_response, ) error_messages = get_error_messages_from_codes(action, error_codes=error_codes) - return errors_payment_result(error_messages) + return errors_payment_result(error_messages, raw_response) def report_fulfillment( @@ -216,7 +221,7 @@ def report_fulfillment( """ shipping_company_code = get_shipping_company_code(config, fulfillment) - result, error_codes = report( + _result, error_codes, _raw_response = report( config, shipping_company_code, payment.psp_reference, diff --git a/saleor/payment/gateways/np_atobarai/api_helpers.py b/saleor/payment/gateways/np_atobarai/api_helpers.py index 2022ab7a5f0..a448cc4d95f 100644 --- a/saleor/payment/gateways/np_atobarai/api_helpers.py +++ b/saleor/payment/gateways/np_atobarai/api_helpers.py @@ -72,15 +72,16 @@ def _request( def np_request( config: "ApiConfig", method: str, path: str = "", json: Optional[dict] = None ) -> NPResponse: + response_data = {} try: response = _request(config, method, path, json) response_data = response.json() if "errors" in response_data: - return NPResponse({}, response_data["errors"][0]["codes"]) - return NPResponse(response_data["results"][0], []) + return NPResponse({}, response_data["errors"][0]["codes"], response_data) + return NPResponse(response_data["results"][0], [], response_data) except requests.RequestException: logger.warning("Cannot connect to NP Atobarai.", exc_info=True) - return NPResponse({}, [NP_CONNECTION_ERROR]) + return NPResponse({}, [NP_CONNECTION_ERROR], response_data) def handle_unrecoverable_state( diff --git a/saleor/payment/gateways/np_atobarai/api_types.py b/saleor/payment/gateways/np_atobarai/api_types.py index 549dd58a8df..a491e40a9a8 100644 --- a/saleor/payment/gateways/np_atobarai/api_types.py +++ b/saleor/payment/gateways/np_atobarai/api_types.py @@ -18,10 +18,11 @@ class NPResponse(NamedTuple): result: dict error_codes: List[str] + raw_response: dict def error_np_response(error_message: str) -> NPResponse: - return NPResponse({}, [error_message]) + return NPResponse({}, [error_message], {}) @dataclass @@ -51,11 +52,15 @@ class PaymentResult: def error_payment_result(error_message: str) -> PaymentResult: - return PaymentResult(status=PaymentStatus.FAILED, errors=[error_message]) + return PaymentResult( + status=PaymentStatus.FAILED, errors=[error_message], raw_response={} + ) -def errors_payment_result(errors: List[str]) -> PaymentResult: - return PaymentResult(status=PaymentStatus.FAILED, errors=errors) +def errors_payment_result(errors: List[str], response: dict) -> PaymentResult: + return PaymentResult( + status=PaymentStatus.FAILED, errors=errors, raw_response=response + ) def get_api_config(connection_params: dict) -> ApiConfig: diff --git a/saleor/payment/gateways/np_atobarai/tests/test_api.py b/saleor/payment/gateways/np_atobarai/tests/test_api.py index 633056237ff..52517532a81 100644 --- a/saleor/payment/gateways/np_atobarai/tests/test_api.py +++ b/saleor/payment/gateways/np_atobarai/tests/test_api.py @@ -30,10 +30,11 @@ def test_refund_payment( payment_data.amount = payment_dummy.captured_amount psp_reference = "18121200001" payment_data.psp_reference = psp_reference + response_value = {"results": [{"np_transaction_id": psp_reference}]} response = Mock( spec=requests.Response, status_code=200, - json=Mock(return_value={"results": [{"np_transaction_id": psp_reference}]}), + json=Mock(return_value=response_value), ) mocked_request.return_value = response @@ -42,6 +43,7 @@ def test_refund_payment( # then assert gateway_response.is_success + assert gateway_response.raw_response == response_value def test_refund_payment_no_order(np_atobarai_plugin, np_payment_data, payment_dummy): @@ -59,17 +61,18 @@ def test_refund_payment_no_order(np_atobarai_plugin, np_payment_data, payment_du @patch("saleor.payment.gateways.np_atobarai.api_helpers.requests.request") -def test_refund_payment_payment_not_created( +def test_refund_payment_no_psp_reference_payment_not_created( mocked_request, np_atobarai_plugin, np_payment_data, payment_dummy ): # given plugin = np_atobarai_plugin() payment_data = np_payment_data payment_data.amount = payment_dummy.captured_amount + response_value = {"results": [{"np_transaction_id": "18121200001"}]} response = Mock( spec=requests.Response, status_code=200, - json=Mock(return_value={"results": [{"np_transaction_id": "18121200001"}]}), + json=Mock(return_value=response_value), ) mocked_request.return_value = response @@ -78,6 +81,7 @@ def test_refund_payment_payment_not_created( # then assert not gateway_response.is_success + assert not gateway_response.raw_response @patch.object(HTTPSession, "request") @@ -210,10 +214,11 @@ def test_report_fulfillment(mocked_request, config, fulfillment, payment_dummy): psp_reference = "18121200001" payment_dummy.psp_reference = psp_reference payment_dummy.save(update_fields=["psp_reference"]) + response_value = {"results": [{"np_transaction_id": psp_reference}]} response = Mock( spec=requests.Response, status_code=200, - json=Mock(return_value={"results": [{"np_transaction_id": psp_reference}]}), + json=Mock(return_value=response_value), ) mocked_request.return_value = response @@ -235,10 +240,11 @@ def test_report_fulfillment_invalid_shipping_company_code( psp_reference = "18121200001" payment_dummy.psp_reference = psp_reference payment_dummy.save(update_fields=["psp_reference"]) + response_value = {"results": [{"np_transaction_id": psp_reference}]} response = Mock( spec=requests.Response, status_code=200, - json=Mock(return_value={"results": [{"np_transaction_id": psp_reference}]}), + json=Mock(return_value=response_value), ) mocked_request.return_value = response fulfillment.store_value_in_private_metadata( @@ -298,10 +304,11 @@ def test_report_fulfillment_np_errors( payment_dummy.psp_reference = psp_reference payment_dummy.save(update_fields=["psp_reference"]) error_codes = ["EPRO0101", "EPRO0102"] + response_value = {"errors": [{"codes": error_codes}]} response = Mock( spec=requests.Response, status_code=400, - json=Mock(return_value={"errors": [{"codes": error_codes}]}), + json=Mock(return_value=response_value), ) mocked_request.return_value = response @@ -348,10 +355,11 @@ def test_report_fulfillment_already_captured( ): # given payment_dummy.psp_reference = "123123123" + response_value = {"errors": [{"codes": ["E0100115"]}]} response = Mock( spec=requests.Response, status_code=400, - json=Mock(return_value={"errors": [{"codes": ["E0100115"]}]}), + json=Mock(return_value=response_value), ) mocked_request.return_value = response @@ -496,19 +504,18 @@ def test_change_transaction_success( mocked_request, config, payment_dummy, np_payment_data ): # given + response_value = { + "results": [ + { + "authori_result": "00", + "np_transaction_id": payment_dummy.psp_reference, + } + ] + } response = Mock( spec=requests.Response, status_code=200, - json=Mock( - return_value={ - "results": [ - { - "authori_result": "00", - "np_transaction_id": payment_dummy.psp_reference, - } - ] - } - ), + json=Mock(return_value=response_value), ) mocked_request.return_value = response @@ -519,6 +526,8 @@ def test_change_transaction_success( # then assert payment_response.status == PaymentStatus.SUCCESS + assert payment_response.raw_response == response_value + assert payment_response.raw_response == response_value @patch("saleor.payment.gateways.np_atobarai.api.cancel") @@ -528,23 +537,22 @@ def test_change_transaction_pending( ): # given transaction_id = "123" + response_value = { + "results": [ + { + "authori_result": "10", + "np_transaction_id": transaction_id, + "authori_hold": [ + "RE009", + "REE021", + ], + } + ] + } response = Mock( spec=requests.Response, status_code=200, - json=Mock( - return_value={ - "results": [ - { - "authori_result": "10", - "np_transaction_id": transaction_id, - "authori_hold": [ - "RE009", - "REE021", - ], - } - ] - } - ), + json=Mock(return_value=response_value), ) mocked_request.return_value = response @@ -556,6 +564,7 @@ def test_change_transaction_pending( # then mocked_cancel.assert_called_once_with(config, transaction_id) assert payment_response.status == PaymentStatus.FAILED + assert payment_response.raw_response == response_value @patch.object(HTTPSession, "request") @@ -563,10 +572,11 @@ def test_change_transaction_post_fulfillment( mocked_request, config, payment_dummy, np_payment_data ): # given + response_value = {"errors": [{"codes": ["E0100115"]}]} response = Mock( spec=requests.Response, status_code=200, - json=Mock(return_value={"errors": [{"codes": ["E0100115"]}]}), + json=Mock(return_value=response_value), ) mocked_request.return_value = response @@ -577,6 +587,7 @@ def test_change_transaction_post_fulfillment( # then assert payment_response.status == PaymentStatus.FOR_REREGISTRATION + assert payment_response.raw_response == response_value @patch.object(HTTPSession, "request") @@ -584,10 +595,11 @@ def test_change_transaction_failed( mocked_request, config, payment_dummy, np_payment_data ): # given + response_value = {"errors": [{"codes": ["E0100050"]}]} response = Mock( spec=requests.Response, status_code=200, - json=Mock(return_value={"errors": [{"codes": ["E0100050"]}]}), + json=Mock(return_value=response_value), ) mocked_request.return_value = response @@ -599,6 +611,7 @@ def test_change_transaction_failed( # then assert payment_response.status == PaymentStatus.FAILED assert payment_response.errors + assert payment_response.raw_response == response_value @patch("saleor.payment.gateways.np_atobarai.api.report") @@ -617,14 +630,29 @@ def test_reregister_transaction_success( shipping_company_code = "50000" payment_dummy.psp_reference = "123" new_psp_reference = "234" - mocked_cancel.return_value = NPResponse(result={}, error_codes=[]) + mocked_cancel.return_value = NPResponse(result={}, error_codes=[], raw_response={}) + transaction_id = "18121200001" + register_response_value = { + "results": [ + { + "shop_transaction_id": "abc1234567890", + "np_transaction_id": transaction_id, + "authori_result": "00", + "authori_required_date": "2018-12-12T12:00:00+09:00", + } + ] + } mocked_register.return_value = NPResponse( result={ "np_transaction_id": new_psp_reference, }, error_codes=[], + raw_response=register_response_value, + ) + report_response_value = {"results": [{"np_transaction_id": transaction_id}]} + mocked_report.return_value = NPResponse( + result={}, error_codes=[], raw_response=report_response_value ) - mocked_report.return_value = NPResponse(result={}, error_codes=[]) # when goods, billed_amount = get_goods_with_refunds( @@ -653,6 +681,7 @@ def test_reregister_transaction_success( ) assert payment_response.status == PaymentStatus.SUCCESS assert payment_response.psp_reference == new_psp_reference + assert payment_response.raw_response == report_response_value def test_reregister_transaction_no_psp_reference(payment_dummy, np_payment_data): @@ -679,7 +708,9 @@ def test_reregister_transaction_cancel_error( # given payment_dummy.psp_reference = "123" error_codes = ["1", "2", "3"] - mocked_cancel.return_value = NPResponse(result={}, error_codes=error_codes) + mocked_cancel.return_value = NPResponse( + result={}, error_codes=error_codes, raw_response={} + ) # when payment_response = api.reregister_transaction_for_partial_return( @@ -708,14 +739,17 @@ def test_reregister_transaction_report_error( payment_dummy.psp_reference = "123" new_psp_reference = "234" error_codes = ["1", "2", "3"] - mocked_cancel.return_value = NPResponse(result={}, error_codes=[]) + mocked_cancel.return_value = NPResponse(result={}, error_codes=[], raw_response={}) mocked_register.return_value = NPResponse( result={ "np_transaction_id": new_psp_reference, }, error_codes=[], + raw_response={}, + ) + mocked_report.return_value = NPResponse( + result={}, error_codes=error_codes, raw_response={} ) - mocked_report.return_value = NPResponse(result={}, error_codes=error_codes) # when payment_response = api.reregister_transaction_for_partial_return( @@ -744,9 +778,11 @@ def test_reregister_transaction_general_error( ): # given payment_dummy.psp_reference = "123" - mocked_cancel.return_value = NPResponse(result={}, error_codes=[]) + mocked_cancel.return_value = NPResponse(result={}, error_codes=[], raw_response={}) error_codes = ["1", "2", "3"] - mocked_register.return_value = NPResponse(result={}, error_codes=error_codes) + mocked_register.return_value = NPResponse( + result={}, error_codes=error_codes, raw_response={} + ) # when payment_response = api.reregister_transaction_for_partial_return( @@ -771,8 +807,10 @@ def test_register_transaction_pending( "np_transaction_id": transaction_id, "authori_hold": errors, } - mocked_register.return_value = NPResponse(result=register_result, error_codes=[]) - mocked_cancel.return_value = NPResponse(result={}, error_codes=[]) + mocked_register.return_value = NPResponse( + result=register_result, error_codes=[], raw_response={} + ) + mocked_cancel.return_value = NPResponse(result={}, error_codes=[], raw_response={}) # when payment_response = api.register_transaction(config, np_payment_data) @@ -797,9 +835,13 @@ def test_register_transaction_pending_unrecoverable( "np_transaction_id": transaction_id, "authori_hold": errors, } - mocked_register.return_value = NPResponse(result=register_result, error_codes=[]) + mocked_register.return_value = NPResponse( + result=register_result, error_codes=[], raw_response=register_result + ) cancel_errors = ["1", "2", "3"] - mocked_cancel.return_value = NPResponse(result={}, error_codes=cancel_errors) + mocked_cancel.return_value = NPResponse( + result={}, error_codes=cancel_errors, raw_response={} + ) # when payment_response = api.register_transaction(config, np_payment_data) @@ -810,6 +852,7 @@ def test_register_transaction_pending_unrecoverable( assert payment_response.status == PaymentStatus.PENDING assert payment_response.errors == [f"TR#{e}" for e in errors] assert caplog.record_tuples == [(ANY, logging.ERROR, ANY)] + assert payment_response.raw_response == register_result def test_cancel_transaction_no_payment(np_payment_data): From a429349d7e5dffcca040d604b0f22078c8f7899f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tomasz=20Szyma=C5=84ski?= Date: Wed, 8 May 2024 11:47:04 +0200 Subject: [PATCH 4/5] Fix transactionUpdate/transactionCreate with empty metadata (#15888) --- .../transaction/transaction_create.py | 4 +- .../mutations/test_transaction_create.py | 118 ++++++++++++++++++ .../mutations/test_transaction_update.py | 52 ++++++++ 3 files changed, 173 insertions(+), 1 deletion(-) diff --git a/saleor/graphql/payment/mutations/transaction/transaction_create.py b/saleor/graphql/payment/mutations/transaction/transaction_create.py index c9de09d630f..ac76be8ac1f 100644 --- a/saleor/graphql/payment/mutations/transaction/transaction_create.py +++ b/saleor/graphql/payment/mutations/transaction/transaction_create.py @@ -136,8 +136,10 @@ def validate_external_url(cls, external_url: Optional[str], error_code: str): @classmethod def validate_metadata_keys( # type: ignore[override] - cls, metadata_list: List[dict], field_name, error_code + cls, metadata_list: Optional[List[dict]], field_name, error_code ): + if not metadata_list: + return if metadata_contains_empty_key(metadata_list): raise ValidationError( { diff --git a/saleor/graphql/payment/tests/mutations/test_transaction_create.py b/saleor/graphql/payment/tests/mutations/test_transaction_create.py index 9c0395174ad..2855acfd153 100644 --- a/saleor/graphql/payment/tests/mutations/test_transaction_create.py +++ b/saleor/graphql/payment/tests/mutations/test_transaction_create.py @@ -195,6 +195,63 @@ def test_transaction_create_for_order_by_app( assert transaction.external_url == external_url +def test_transaction_create_for_order_by_app_metadata_null_value( + order_with_lines, permission_manage_payments, app_api_client +): + # given + name = "Credit Card" + psp_reference = "PSP reference - 123" + available_actions = [ + TransactionActionEnum.CHARGE.name, + TransactionActionEnum.CHARGE.name, + ] + authorized_value = Decimal("10") + external_url = f"http://{TEST_SERVER_DOMAIN}/external-url" + + variables = { + "id": graphene.Node.to_global_id("Order", order_with_lines.pk), + "transaction": { + "name": name, + "pspReference": psp_reference, + "availableActions": available_actions, + "amountAuthorized": { + "amount": authorized_value, + "currency": "USD", + }, + "metadata": None, + "privateMetadata": None, + "externalUrl": external_url, + }, + } + + # when + response = app_api_client.post_graphql( + MUTATION_TRANSACTION_CREATE, variables, permissions=[permission_manage_payments] + ) + + # then + available_actions = list(set(available_actions)) + + transaction = order_with_lines.payment_transactions.first() + content = get_graphql_content(response) + data = content["data"]["transactionCreate"]["transaction"] + assert data["actions"] == available_actions + assert data["pspReference"] == psp_reference + assert data["authorizedAmount"]["amount"] == authorized_value + assert data["externalUrl"] == external_url + assert data["createdBy"]["id"] == to_global_id_or_none(app_api_client.app) + + assert available_actions == list(map(str.upper, transaction.available_actions)) + assert psp_reference == transaction.psp_reference + assert authorized_value == transaction.authorized_value + assert transaction.metadata == {} + assert transaction.private_metadata == {} + assert transaction.app_identifier == app_api_client.app.identifier + assert transaction.app == app_api_client.app + assert transaction.user is None + assert transaction.external_url == external_url + + def test_transaction_create_for_order_updates_order_total_authorized_by_app( order_with_lines, permission_manage_payments, app_api_client ): @@ -351,6 +408,67 @@ def test_transaction_create_for_checkout_by_app( assert transaction.user is None +def test_transaction_create_for_checkout_by_app_metadata_null_value( + checkout_with_items, permission_manage_payments, app_api_client +): + # given + name = "Credit Card" + psp_reference = "PSP reference - 123" + available_actions = [ + TransactionActionEnum.CHARGE.name, + TransactionActionEnum.CHARGE.name, + ] + authorized_value = Decimal("10") + external_url = f"http://{TEST_SERVER_DOMAIN}/external-url" + + variables = { + "id": graphene.Node.to_global_id("Checkout", checkout_with_items.pk), + "transaction": { + "name": name, + "pspReference": psp_reference, + "availableActions": available_actions, + "amountAuthorized": { + "amount": authorized_value, + "currency": "USD", + }, + "metadata": None, + "privateMetadata": None, + "externalUrl": external_url, + }, + } + + # when + response = app_api_client.post_graphql( + MUTATION_TRANSACTION_CREATE, variables, permissions=[permission_manage_payments] + ) + + # then + checkout_with_items.refresh_from_db() + assert checkout_with_items.charge_status == CheckoutChargeStatus.NONE + assert checkout_with_items.authorize_status == CheckoutAuthorizeStatus.PARTIAL + + available_actions = list(set(available_actions)) + + transaction = checkout_with_items.payment_transactions.first() + content = get_graphql_content(response) + data = content["data"]["transactionCreate"]["transaction"] + assert data["actions"] == available_actions + assert data["pspReference"] == psp_reference + assert data["authorizedAmount"]["amount"] == authorized_value + assert data["externalUrl"] == external_url + assert data["createdBy"]["id"] == to_global_id_or_none(app_api_client.app) + + assert available_actions == list(map(str.upper, transaction.available_actions)) + assert psp_reference == transaction.psp_reference + assert authorized_value == transaction.authorized_value + assert transaction.metadata == {} + assert transaction.private_metadata == {} + assert transaction.external_url == external_url + assert transaction.app_identifier == app_api_client.app.identifier + assert transaction.app == app_api_client.app + assert transaction.user is None + + @pytest.mark.parametrize( "amount_field_name, amount_db_field", [ diff --git a/saleor/graphql/payment/tests/mutations/test_transaction_update.py b/saleor/graphql/payment/tests/mutations/test_transaction_update.py index 73f3cdd1a0b..306ce482830 100644 --- a/saleor/graphql/payment/tests/mutations/test_transaction_update.py +++ b/saleor/graphql/payment/tests/mutations/test_transaction_update.py @@ -241,6 +241,31 @@ def test_transaction_update_metadata_by_app( assert transaction_item_created_by_app.metadata == {meta_key: meta_value} +def test_transaction_update_metadata_by_app_null_value( + transaction_item_created_by_app, permission_manage_payments, app_api_client +): + # given + transaction = transaction_item_created_by_app + + variables = { + "id": graphene.Node.to_global_id("TransactionItem", transaction.token), + "transaction": { + "metadata": None, + }, + } + + # when + response = app_api_client.post_graphql( + MUTATION_TRANSACTION_UPDATE, variables, permissions=[permission_manage_payments] + ) + + # then + transaction.refresh_from_db() + content = get_graphql_content(response) + data = content["data"]["transactionUpdate"]["transaction"] + assert len(data["metadata"]) == 0 + + def test_transaction_update_metadata_incorrect_key_by_app( transaction_item_created_by_app, permission_manage_payments, app_api_client ): @@ -300,6 +325,33 @@ def test_transaction_update_private_metadata_by_app( assert transaction_item_created_by_app.private_metadata == {meta_key: meta_value} +def test_transaction_update_private_metadata_by_app_null_value( + transaction_item_created_by_app, permission_manage_payments, app_api_client +): + # given + transaction = transaction_item_created_by_app + transaction.private_metadata = {"key": "value"} + transaction.save(update_fields=["private_metadata"]) + + variables = { + "id": graphene.Node.to_global_id("TransactionItem", transaction.token), + "transaction": { + "privateMetadata": None, + }, + } + + # when + response = app_api_client.post_graphql( + MUTATION_TRANSACTION_UPDATE, variables, permissions=[permission_manage_payments] + ) + + # then + transaction.refresh_from_db() + content = get_graphql_content(response) + data = content["data"]["transactionUpdate"]["transaction"] + assert len(data["privateMetadata"]) == 1 + + def test_transaction_update_private_metadata_incorrect_key_by_app( transaction_item_created_by_app, permission_manage_payments, app_api_client ): From cd445a13f231bacd6760da2897f6f25cef042c2d Mon Sep 17 00:00:00 2001 From: Artur <31221055+Air-t@users.noreply.github.com> Date: Fri, 10 May 2024 12:22:39 +0200 Subject: [PATCH 5/5] add missing filtering for active plugins on manager (#15946) --- saleor/plugins/manager.py | 15 ++++++------ saleor/plugins/tests/test_manager.py | 36 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/saleor/plugins/manager.py b/saleor/plugins/manager.py index 26ae076e64a..1b316f8ca78 100644 --- a/saleor/plugins/manager.py +++ b/saleor/plugins/manager.py @@ -1798,13 +1798,14 @@ def __run_plugin_method_until_first_success( *args, channel_slug: Optional[str] = None, ): - plugins = self.get_plugins(channel_slug=channel_slug) - for plugin in plugins: - result = self.__run_method_on_single_plugin( - plugin, method_name, None, *args - ) - if result is not None: - return result + plugins = self.get_plugins(channel_slug=channel_slug, active_only=True) + if plugins: + for plugin in plugins: + result = self.__run_method_on_single_plugin( + plugin, method_name, None, *args + ) + if result is not None: + return result return None def _get_all_plugin_configs(self): diff --git a/saleor/plugins/tests/test_manager.py b/saleor/plugins/tests/test_manager.py index 664a7a9aee0..07e4a100e5e 100644 --- a/saleor/plugins/tests/test_manager.py +++ b/saleor/plugins/tests/test_manager.py @@ -1546,3 +1546,39 @@ def test_plugin_manager__get_channel_map( channel_JPY.pk: channel_JPY, other_channel_USD.pk: other_channel_USD, } + + +@pytest.mark.parametrize( + ("plugins", "calls"), + [ + ([], 0), + (["saleor.plugins.tests.sample_plugins.PluginInactive"], 0), + (["saleor.plugins.tests.sample_plugins.PluginSample"], 1), + ( + [ + "saleor.plugins.tests.sample_plugins.PluginInactive", + "saleor.plugins.tests.sample_plugins.PluginSample", + ], + 1, + ), + ], +) +def test_run_plugin_method_until_first_success_for_active_plugins_only( + channel_USD, plugins, calls +): + # given + manager = PluginsManager(plugins=plugins) + + # when + with patch.object( + PluginsManager, + "_PluginsManager__run_method_on_single_plugin", + return_value=None, + ) as mock_run_method: + result = manager._PluginsManager__run_plugin_method_until_first_success( + "some_method", channel_slug=None + ) + + # then + assert result is None + assert mock_run_method.call_count == calls