From 3fc701158528c3808a8f479846b1f033b4b8df2c Mon Sep 17 00:00:00 2001 From: fowczrek Date: Thu, 29 Dec 2022 12:54:30 +0100 Subject: [PATCH] Use replicas to fetching disount info --- saleor/discount/utils.py | 48 ++++++++++++++++++-------- saleor/graphql/discount/dataloaders.py | 12 ++++--- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/saleor/discount/utils.py b/saleor/discount/utils.py index 47b3d917e67..d1790dad324 100644 --- a/saleor/discount/utils.py +++ b/saleor/discount/utils.py @@ -14,6 +14,7 @@ cast, ) +from django.conf import settings from django.db.models import F from django.utils import timezone from prices import Money, TaxedMoney @@ -100,7 +101,7 @@ def get_product_discounts( collections: Iterable["Collection"], discounts: Iterable[DiscountInfo], channel: "Channel", - variant_id: Optional[int] = None + variant_id: Optional[int] = None, ) -> Iterator[Tuple[int, Callable]]: """Return sale ids, discount values for all discounts applicable to a product.""" product_collections = set(pc.id for pc in collections) @@ -120,7 +121,7 @@ def get_sale_id_with_min_price( collections: Iterable["Collection"], discounts: Optional[Iterable[DiscountInfo]], channel: "Channel", - variant_id: Optional[int] = None + variant_id: Optional[int] = None, ) -> Tuple[Optional[int], Money]: """Return a sale_id and minimum product's price.""" available_discounts = [ @@ -150,7 +151,7 @@ def calculate_discounted_price( collections: Iterable["Collection"], discounts: Optional[Iterable[DiscountInfo]], channel: "Channel", - variant_id: Optional[int] = None + variant_id: Optional[int] = None, ) -> Money: """Return minimum product's price of all prices with discounts applied.""" if discounts: @@ -172,7 +173,7 @@ def get_sale_id_applied_as_a_discount( collections: Iterable["Collection"], discounts: Optional[Iterable[DiscountInfo]], channel: "Channel", - variant_id: Optional[int] = None + variant_id: Optional[int] = None, ) -> Optional[int]: """Return an ID of Sale applied to product.""" if not discounts: @@ -257,11 +258,15 @@ def get_products_voucher_discount( return total_amount -def fetch_categories(sale_pks: Iterable[str]) -> Dict[int, Set[int]]: +def fetch_categories( + sale_pks: Iterable[str], + database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, +) -> Dict[int, Set[int]]: from ..product.models import Category categories = ( - Sale.categories.through.objects.filter(sale_id__in=sale_pks) + Sale.categories.through.objects.using(database_connection_name) + .filter(sale_id__in=sale_pks) .order_by("id") .values_list("sale_id", "category_id") ) @@ -278,9 +283,13 @@ def fetch_categories(sale_pks: Iterable[str]) -> Dict[int, Set[int]]: return subcategory_map -def fetch_collections(sale_pks: Iterable[str]) -> Dict[int, Set[int]]: +def fetch_collections( + sale_pks: Iterable[str], + database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, +) -> Dict[int, Set[int]]: collections = ( - Sale.collections.through.objects.filter(sale_id__in=sale_pks) + Sale.collections.through.objects.using(database_connection_name) + .filter(sale_id__in=sale_pks) .order_by("id") .values_list("sale_id", "collection_id") ) @@ -290,9 +299,13 @@ def fetch_collections(sale_pks: Iterable[str]) -> Dict[int, Set[int]]: return collection_map -def fetch_products(sale_pks: Iterable[str]) -> Dict[int, Set[int]]: +def fetch_products( + sale_pks: Iterable[str], + database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, +) -> Dict[int, Set[int]]: products = ( - Sale.products.through.objects.filter(sale_id__in=sale_pks) + Sale.products.through.objects.using(database_connection_name) + .filter(sale_id__in=sale_pks) .order_by("id") .values_list("sale_id", "product_id") ) @@ -302,9 +315,13 @@ def fetch_products(sale_pks: Iterable[str]) -> Dict[int, Set[int]]: return product_map -def fetch_variants(sale_pks: Iterable[str]) -> Dict[int, Set[int]]: +def fetch_variants( + sale_pks: Iterable[str], + database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, +) -> Dict[int, Set[int]]: variants = ( - Sale.variants.through.objects.filter(sale_id__in=sale_pks) + Sale.variants.through.objects.using(database_connection_name) + .filter(sale_id__in=sale_pks) .order_by("id") .values_list("sale_id", "productvariant_id") ) @@ -316,9 +333,12 @@ def fetch_variants(sale_pks: Iterable[str]) -> Dict[int, Set[int]]: def fetch_sale_channel_listings( sale_pks: Iterable[str], + database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME, ): - channel_listings = SaleChannelListing.objects.filter(sale_id__in=sale_pks).annotate( - channel_slug=F("channel__slug") + channel_listings = ( + SaleChannelListing.objects.using(database_connection_name) + .filter(sale_id__in=sale_pks) + .annotate(channel_slug=F("channel__slug")) ) channel_listings_map: Dict[int, Dict[str, SaleChannelListing]] = defaultdict(dict) for channel_listing in channel_listings: diff --git a/saleor/graphql/discount/dataloaders.py b/saleor/graphql/discount/dataloaders.py index 8cfe2c51e7c..adde23a558e 100644 --- a/saleor/graphql/discount/dataloaders.py +++ b/saleor/graphql/discount/dataloaders.py @@ -33,11 +33,13 @@ def batch_load(self, keys): for datetime in keys } pks = {s.pk for d, ss in sales_map.items() for s in ss} - collections = fetch_collections(pks) - channel_listings = fetch_sale_channel_listings(pks) - products = fetch_products(pks) - categories = fetch_categories(pks) - variants = fetch_variants(pks) + collections = fetch_collections(pks, self.database_connection_name) + channel_listings = fetch_sale_channel_listings( + pks, self.database_connection_name + ) + products = fetch_products(pks, self.database_connection_name) + categories = fetch_categories(pks, self.database_connection_name) + variants = fetch_variants(pks, self.database_connection_name) return [ [