diff --git a/saleor/graphql/product/bulk_mutations/product_variant_stocks_create.py b/saleor/graphql/product/bulk_mutations/product_variant_stocks_create.py index 6f4d2df2984..58004398c6b 100644 --- a/saleor/graphql/product/bulk_mutations/product_variant_stocks_create.py +++ b/saleor/graphql/product/bulk_mutations/product_variant_stocks_create.py @@ -14,9 +14,7 @@ from ...core.mutations import BaseMutation from ...core.types import BulkStockError, NonNullList from ...plugins.dataloaders import get_plugin_manager_promise -from ...warehouse.dataloaders import ( - StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader, -) +from ...warehouse.dataloaders import StocksByProductVariantIdLoader from ...warehouse.types import Warehouse from ..mutations.product.product_create import StockInput from ..types import ProductVariant @@ -69,9 +67,7 @@ def perform_mutation(cls, _root, info: ResolveInfo, /, **data): manager.product_variant_back_in_stock, stock, webhooks=webhooks ) - StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( - info.context - ).clear((variant.id, None, None)) + StocksByProductVariantIdLoader(info.context).clear(variant.id) variant = ChannelContext(node=variant, channel_slug=None) return cls(product_variant=variant) diff --git a/saleor/graphql/product/bulk_mutations/product_variant_stocks_delete.py b/saleor/graphql/product/bulk_mutations/product_variant_stocks_delete.py index 93ab24e7567..9a14165fc33 100644 --- a/saleor/graphql/product/bulk_mutations/product_variant_stocks_delete.py +++ b/saleor/graphql/product/bulk_mutations/product_variant_stocks_delete.py @@ -14,9 +14,7 @@ from ...core.types import NonNullList, StockError from ...core.validators import validate_one_of_args_is_in_mutation from ...plugins.dataloaders import get_plugin_manager_promise -from ...warehouse.dataloaders import ( - StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader, -) +from ...warehouse.dataloaders import StocksByProductVariantIdLoader from ...warehouse.types import Warehouse from ..types import ProductVariant @@ -85,9 +83,7 @@ def perform_mutation(cls, _root, info: ResolveInfo, /, **data): stocks_to_delete.delete() - StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( - info.context - ).clear((variant.id, None, None)) + StocksByProductVariantIdLoader(info.context).clear(variant.id) variant = ChannelContext(node=variant, channel_slug=None) return cls(product_variant=variant) diff --git a/saleor/graphql/product/bulk_mutations/product_variant_stocks_update.py b/saleor/graphql/product/bulk_mutations/product_variant_stocks_update.py index 2cc54d2c3f2..36ff6f7528e 100644 --- a/saleor/graphql/product/bulk_mutations/product_variant_stocks_update.py +++ b/saleor/graphql/product/bulk_mutations/product_variant_stocks_update.py @@ -15,9 +15,7 @@ from ...core.types import BulkStockError, NonNullList from ...core.validators import validate_one_of_args_is_in_mutation from ...plugins.dataloaders import get_plugin_manager_promise -from ...warehouse.dataloaders import ( - StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader, -) +from ...warehouse.dataloaders import StocksByProductVariantIdLoader from ...warehouse.types import Warehouse from ..mutations.product.product_create import StockInput from ..types import ProductVariant @@ -81,9 +79,7 @@ def perform_mutation(cls, _root, info: ResolveInfo, /, **data): manager = get_plugin_manager_promise(info.context).get() cls.update_or_create_variant_stocks(variant, stocks, warehouses, manager) - StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( - info.context - ).clear((variant.id, None, None)) + StocksByProductVariantIdLoader(info.context).clear(variant.id) variant = ChannelContext(node=variant, channel_slug=None) return cls(product_variant=variant) diff --git a/saleor/graphql/product/tests/queries/test_product_variant_query.py b/saleor/graphql/product/tests/queries/test_product_variant_query.py index c9da985ba37..fd77cb3cc72 100644 --- a/saleor/graphql/product/tests/queries/test_product_variant_query.py +++ b/saleor/graphql/product/tests/queries/test_product_variant_query.py @@ -3,6 +3,7 @@ from measurement.measures import Weight from .....core.units import WeightUnits +from .....warehouse import WarehouseClickAndCollectOption from ....core.enums import WeightUnitsEnum from ....tests.utils import assert_no_permission, get_graphql_content @@ -167,6 +168,41 @@ def test_fetch_variant_no_stocks( ) +def test_fetch_variant_stocks_from_click_and_collect_warehouse( + staff_api_client, + product, + permission_manage_products, + channel_USD, +): + # given + query = QUERY_VARIANT + variant = product.variants.first() + stocks_count = variant.stocks.count() + warehouse = variant.stocks.first().warehouse + + # remove the warehouse shipping zones and mark it as click and collect + # the stocks for this warehouse should be still returned + warehouse.shipping_zones.clear() + warehouse.click_and_collect_option = WarehouseClickAndCollectOption.LOCAL_STOCK + warehouse.save(update_fields=["click_and_collect_option"]) + + variant_id = graphene.Node.to_global_id("ProductVariant", variant.pk) + variables = {"id": variant_id, "countryCode": "EU", "channel": channel_USD.slug} + staff_api_client.user.user_permissions.add(permission_manage_products) + + # when + response = staff_api_client.post_graphql(query, variables) + + # then + content = get_graphql_content(response) + data = content["data"]["productVariant"] + assert data["name"] == variant.name + assert data["created"] == variant.created_at.isoformat() + + assert len(data["stocksByAddress"]) == stocks_count + assert not data["deprecatedStocksByCountry"] + + QUERY_PRODUCT_VARIANT_CHANNEL_LISTING = """ query ProductVariantDetails($id: ID!, $channel: String) { productVariant(id: $id, channel: $channel) { diff --git a/saleor/graphql/product/types/products.py b/saleor/graphql/product/types/products.py index afb6a272649..617936717ad 100644 --- a/saleor/graphql/product/types/products.py +++ b/saleor/graphql/product/types/products.py @@ -117,6 +117,7 @@ from ...warehouse.dataloaders import ( AvailableQuantityByProductVariantIdCountryCodeAndChannelSlugLoader, PreorderQuantityReservedByVariantChannelListingIdLoader, + StocksByProductVariantIdLoader, StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader, ) from ...warehouse.types import Stock @@ -436,9 +437,13 @@ def resolve_stocks( ): if address is not None: country_code = address.country - return StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( - info.context - ).load((root.node.id, country_code, root.channel_slug)) + channle_slug = root.channel_slug + if channle_slug or country_code: + return StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( # noqa: E501 + info.context + ).load((root.node.id, country_code, root.channel_slug)) + else: + return StocksByProductVariantIdLoader(info.context).load(root.node.id) @staticmethod @load_site_callback diff --git a/saleor/graphql/shipping/dataloaders.py b/saleor/graphql/shipping/dataloaders.py index b5c490a6401..85462fdb5a0 100644 --- a/saleor/graphql/shipping/dataloaders.py +++ b/saleor/graphql/shipping/dataloaders.py @@ -1,6 +1,6 @@ from collections import defaultdict -from django.db.models import Exists, F, OuterRef +from django.db.models import Exists, F, OuterRef, Q from ...channel.models import Channel from ...shipping.models import ( @@ -220,3 +220,23 @@ def map_channels(channels): .load_many({pk for pk, _ in channel_and_zone_is_pairs}) .then(map_channels) ) + + +class ShippingZonesByCountryLoader(DataLoader): + context_key = "shippingzones_by_country" + + def batch_load(self, keys): + lookup = Q() + for key in keys: + lookup |= Q(countries__contains=key) + shipping_zones = ShippingZone.objects.using( + self.database_connection_name + ).filter(lookup) + + shipping_zones_by_country = defaultdict(list) + for shipping_zone in shipping_zones: + for country_code in keys: + if country_code in shipping_zone.countries: + shipping_zones_by_country[country_code].append(shipping_zone) + + return [shipping_zones_by_country[key] for key in keys] diff --git a/saleor/graphql/warehouse/dataloaders.py b/saleor/graphql/warehouse/dataloaders.py index 1b3fae7fa94..eb517985c41 100644 --- a/saleor/graphql/warehouse/dataloaders.py +++ b/saleor/graphql/warehouse/dataloaders.py @@ -15,6 +15,7 @@ from django.db.models.functions import Coalesce from django.utils import timezone from django_stubs_ext import WithAnnotations +from promise import Promise from ...channel.models import Channel from ...product.models import ProductVariantChannelListing @@ -28,7 +29,12 @@ Warehouse, ) from ...warehouse.reservations import is_reservation_enabled +from ..channel.dataloaders import ChannelBySlugLoader from ..core.dataloaders import DataLoader +from ..shipping.dataloaders import ( + ShippingZonesByChannelIdLoader, + ShippingZonesByCountryLoader, +) from ..site.dataloaders import get_site_promise if TYPE_CHECKING: @@ -385,79 +391,183 @@ class StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( context_key = "stocks_with_available_quantity_by_productvariant_country_and_channel" def batch_load(self, keys): - # Split the list of keys by country first. A typical query will only touch - # a handful of unique countries but may access thousands of product variants - # so it's cheaper to execute one query per country. - variants_by_country_and_channel: defaultdict[ - tuple[CountryCode, str], list[int] - ] = defaultdict(list) - for variant_id, country_code, channel_slug in keys: - variants_by_country_and_channel[(country_code, channel_slug)].append( - variant_id - ) - - # For each country code execute a single query for all product variants. - stocks_by_variant_and_country: defaultdict[ - VariantIdCountryCodeChannelSlug, list[Stock] - ] = defaultdict(list) - for key, variant_ids in variants_by_country_and_channel.items(): - country_code, channel_slug = key - variant_ids_stocks = self.batch_load_stocks_by_country( - country_code, channel_slug, variant_ids - ) - for variant_id, stocks in variant_ids_stocks: - stocks_by_variant_and_country[ - (variant_id, country_code, channel_slug) - ].extend(stocks) + def with_channels(channels): + def with_shipping_zones(data): + def with_warehouses(warehouse_data): + warehouses_by_channel, warehouses_by_zone = warehouse_data + + # build maps + variant_ids_by_country_and_channel_map: defaultdict[ + tuple[CountryCode, str], list[int] + ] = defaultdict(list) + for variant_id, country_code, channel_slug in keys: + variant_ids_by_country_and_channel_map[ + (country_code, channel_slug) + ].append(variant_id) + + shipping_zones_by_channel_map = { + channel.slug: set(shipping_zones) + for shipping_zones, channel in zip( + shipping_zones_by_channel, channels + ) + } + shipping_zones_by_country_map = { + country_code: set(shipping_zones) + for shipping_zones, country_code in zip( + shipping_zones_by_country, country_codes + ) + } + warehouses_by_channel_map = { + channel.slug: set(warehouses) + for warehouses, channel in zip(warehouses_by_channel, channels) + } + warehouses_by_zone_map = { + shipping_zone_id: set(warehouses) + for warehouses, shipping_zone_id in zip( + warehouses_by_zone, shipping_zone_ids + ) + } + + # filter warehouses + warehouse_ids_by_country_and_channel_map = ( + self.get_relevant_warehouses( + variant_ids_by_country_and_channel_map, + shipping_zones_by_channel_map, + shipping_zones_by_country_map, + warehouses_by_channel_map, + warehouses_by_zone_map, + ) + ) - return [stocks_by_variant_and_country[key] for key in keys] + variant_ids = list(set(key[0] for key in keys)) + warehouse_ids = { + warehouse_id + for warehouse_ids in warehouse_ids_by_country_and_channel_map.values() # noqa: E501 + for warehouse_id in warehouse_ids + } + stocks_qs = Stock.objects.using( + self.database_connection_name + ).filter( + product_variant_id__in=variant_ids, + warehouse_id__in=warehouse_ids, + ) - def batch_load_stocks_by_country( - self, - country_code: Optional[CountryCode], - channel_slug: Optional[str], - variant_ids: Iterable[int], - ) -> Iterable[tuple[int, list[Stock]]]: - # convert to set to not return the same stocks for the same variant twice - variant_ids_set = set(variant_ids) - stocks = ( - Stock.objects.all() - .using(self.database_connection_name) - .filter(product_variant_id__in=variant_ids_set) - ) - if country_code: - stocks = stocks.filter( - warehouse__shipping_zones__countries__contains=country_code - ) - if channel_slug: - # click and collect warehouses don't have to be assigned to the shipping - # zones, the others must - stocks = stocks.filter( - Q( - warehouse__shipping_zones__channels__slug=channel_slug, - warehouse__channels__slug=channel_slug, + stocks_qs = stocks_qs.annotate_available_quantity().order_by("pk") + + results = [] + for variant_id, country_code, channel_slug in keys: + warehouse_ids = warehouse_ids_by_country_and_channel_map[ + (country_code, channel_slug) + ] + stocks = [ + stock + for stock in stocks_qs + if stock.product_variant_id == variant_id + and stock.warehouse_id in warehouse_ids + ] + results.append(stocks) + + return results + + shipping_zones_by_channel, shipping_zones_by_country = data + channel_ids = [channel.id for channel in channels] + + shipping_zone_by_channel_ids = { + shipping_zone.id + for shipping_zones in shipping_zones_by_channel + for shipping_zone in shipping_zones + } + shipping_zone_by_country_ids = { + shipping_zone.id + for shipping_zones in shipping_zones_by_country + for shipping_zone in shipping_zones + } + shipping_zone_ids = ( + shipping_zone_by_channel_ids | shipping_zone_by_country_ids ) - | Q( - warehouse__channels__slug=channel_slug, - warehouse__click_and_collect_option__in=[ - WarehouseClickAndCollectOption.LOCAL_STOCK, - WarehouseClickAndCollectOption.ALL_WAREHOUSES, - ], + + warehouses_by_channel = WarehousesByChannelIdLoader( + self.context + ).load_many(channel_ids) + warehouses_by_zone = WarehousesByShippingZoneIdLoader( + self.context + ).load_many(shipping_zone_ids) + return Promise.all([warehouses_by_channel, warehouses_by_zone]).then( + with_warehouses ) - ) - stocks = stocks.annotate_available_quantity().order_by("pk") - stocks_by_variant_id_map: defaultdict[int, list[Stock]] = defaultdict(list) - for stock in stocks: - stocks_by_variant_id_map[stock.product_variant_id].append(stock) + channel_ids = [channel.id for channel in set(channels)] + shipping_zones_by_channel = ShippingZonesByChannelIdLoader( + self.context + ).load_many(channel_ids) - return [ - ( - variant_id, - stocks_by_variant_id_map[variant_id], - ) - for variant_id in variant_ids_set - ] + country_codes = list(set(key[1] for key in keys if key[1])) + shipping_zones_by_country = ShippingZonesByCountryLoader( + self.context + ).load_many(country_codes) + + return Promise.all( + [shipping_zones_by_channel, shipping_zones_by_country] + ).then(with_shipping_zones) + + channel_slugs = list(set(key[2] for key in keys if key[2])) + return ( + ChannelBySlugLoader(self.context) + .load_many(channel_slugs) + .then(with_channels) + ) + + @staticmethod + def get_relevant_warehouses( + variant_ids_by_country_and_channel_map, + shipping_zones_by_channel_map, + shipping_zones_by_country_map, + warehouses_by_channel_map, + warehouses_by_zone_map, + ): + warehouse_ids_by_country_and_channel_map = defaultdict(list) + for ( + country_code, + channel_slug, + ), variant_ids in variant_ids_by_country_and_channel_map.items(): + warehouses = set() + warehouses_in_country = set() + # get warehouses from shipping zones in specific country + if country_code: + shipping_zones_in_country = shipping_zones_by_country_map[country_code] + for zone in shipping_zones_in_country: + warehouses_in_country |= warehouses_by_zone_map[zone.id] + + if channel_slug: + warehouses_in_channel = warehouses_by_channel_map[channel_slug] + shipping_zones_in_channel = shipping_zones_by_channel_map[channel_slug] + cc_options = [ + WarehouseClickAndCollectOption.LOCAL_STOCK, + WarehouseClickAndCollectOption.ALL_WAREHOUSES, + ] + # get click & collect warehouses available in channel + cc_warehouses_in_channel = { + warehouse + for warehouse in warehouses_in_channel + if warehouse.click_and_collect_option in cc_options + } + + # get warehouses with shipping zone, both available in channel + warehouses_with_zone_in_channel = set() + for zone in shipping_zones_in_channel: + warehouses_with_zone_in_channel |= ( + warehouses_by_zone_map[zone.id] & warehouses_in_channel + ) + + warehouses = cc_warehouses_in_channel | warehouses_with_zone_in_channel + if country_code: + warehouses &= warehouses_in_country + + warehouse_ids_by_country_and_channel_map[(country_code, channel_slug)] = [ + warehouse.id for warehouse in warehouses + ] + + return warehouse_ids_by_country_and_channel_map class StocksReservationsByCheckoutTokenLoader(DataLoader): @@ -621,3 +731,17 @@ def map_warehouses(warehouses): .load_many({pk for pk, _ in warehouse_and_shipping_zone_in_pairs}) .then(map_warehouses) ) + + +class StocksByProductVariantIdLoader(DataLoader): + context_key = "stocks_by_product_variant" + + def batch_load(self, keys): + stocks = Stock.objects.using(self.database_connection_name).filter( + product_variant_id__in=keys + ) + stocks_by_variant_id = defaultdict(list) + for stock in stocks: + stocks_by_variant_id[stock.product_variant_id].append(stock) + + return [stocks_by_variant_id[key] for key in keys]