Skip to content

Commit

Permalink
Use replicas to fetching disount info
Browse files Browse the repository at this point in the history
  • Loading branch information
fowczarek committed Dec 30, 2022
1 parent 8664cf6 commit 3fc7011
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
48 changes: 34 additions & 14 deletions saleor/discount/utils.py
Expand Up @@ -14,6 +14,7 @@
cast,
)

from django.conf import settings
from django.db.models import F
from django.utils import timezone
from prices import Money, TaxedMoney
Expand Down Expand Up @@ -100,7 +101,7 @@ def get_product_discounts(
collections: Iterable["Collection"],
discounts: Iterable[DiscountInfo],
channel: "Channel",
variant_id: Optional[int] = None
variant_id: Optional[int] = None,
) -> Iterator[Tuple[int, Callable]]:
"""Return sale ids, discount values for all discounts applicable to a product."""
product_collections = set(pc.id for pc in collections)
Expand All @@ -120,7 +121,7 @@ def get_sale_id_with_min_price(
collections: Iterable["Collection"],
discounts: Optional[Iterable[DiscountInfo]],
channel: "Channel",
variant_id: Optional[int] = None
variant_id: Optional[int] = None,
) -> Tuple[Optional[int], Money]:
"""Return a sale_id and minimum product's price."""
available_discounts = [
Expand Down Expand Up @@ -150,7 +151,7 @@ def calculate_discounted_price(
collections: Iterable["Collection"],
discounts: Optional[Iterable[DiscountInfo]],
channel: "Channel",
variant_id: Optional[int] = None
variant_id: Optional[int] = None,
) -> Money:
"""Return minimum product's price of all prices with discounts applied."""
if discounts:
Expand All @@ -172,7 +173,7 @@ def get_sale_id_applied_as_a_discount(
collections: Iterable["Collection"],
discounts: Optional[Iterable[DiscountInfo]],
channel: "Channel",
variant_id: Optional[int] = None
variant_id: Optional[int] = None,
) -> Optional[int]:
"""Return an ID of Sale applied to product."""
if not discounts:
Expand Down Expand Up @@ -257,11 +258,15 @@ def get_products_voucher_discount(
return total_amount


def fetch_categories(sale_pks: Iterable[str]) -> Dict[int, Set[int]]:
def fetch_categories(
sale_pks: Iterable[str],
database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME,
) -> Dict[int, Set[int]]:
from ..product.models import Category

categories = (
Sale.categories.through.objects.filter(sale_id__in=sale_pks)
Sale.categories.through.objects.using(database_connection_name)
.filter(sale_id__in=sale_pks)
.order_by("id")
.values_list("sale_id", "category_id")
)
Expand All @@ -278,9 +283,13 @@ def fetch_categories(sale_pks: Iterable[str]) -> Dict[int, Set[int]]:
return subcategory_map


def fetch_collections(sale_pks: Iterable[str]) -> Dict[int, Set[int]]:
def fetch_collections(
sale_pks: Iterable[str],
database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME,
) -> Dict[int, Set[int]]:
collections = (
Sale.collections.through.objects.filter(sale_id__in=sale_pks)
Sale.collections.through.objects.using(database_connection_name)
.filter(sale_id__in=sale_pks)
.order_by("id")
.values_list("sale_id", "collection_id")
)
Expand All @@ -290,9 +299,13 @@ def fetch_collections(sale_pks: Iterable[str]) -> Dict[int, Set[int]]:
return collection_map


def fetch_products(sale_pks: Iterable[str]) -> Dict[int, Set[int]]:
def fetch_products(
sale_pks: Iterable[str],
database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME,
) -> Dict[int, Set[int]]:
products = (
Sale.products.through.objects.filter(sale_id__in=sale_pks)
Sale.products.through.objects.using(database_connection_name)
.filter(sale_id__in=sale_pks)
.order_by("id")
.values_list("sale_id", "product_id")
)
Expand All @@ -302,9 +315,13 @@ def fetch_products(sale_pks: Iterable[str]) -> Dict[int, Set[int]]:
return product_map


def fetch_variants(sale_pks: Iterable[str]) -> Dict[int, Set[int]]:
def fetch_variants(
sale_pks: Iterable[str],
database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME,
) -> Dict[int, Set[int]]:
variants = (
Sale.variants.through.objects.filter(sale_id__in=sale_pks)
Sale.variants.through.objects.using(database_connection_name)
.filter(sale_id__in=sale_pks)
.order_by("id")
.values_list("sale_id", "productvariant_id")
)
Expand All @@ -316,9 +333,12 @@ def fetch_variants(sale_pks: Iterable[str]) -> Dict[int, Set[int]]:

def fetch_sale_channel_listings(
sale_pks: Iterable[str],
database_connection_name: str = settings.DATABASE_CONNECTION_DEFAULT_NAME,
):
channel_listings = SaleChannelListing.objects.filter(sale_id__in=sale_pks).annotate(
channel_slug=F("channel__slug")
channel_listings = (
SaleChannelListing.objects.using(database_connection_name)
.filter(sale_id__in=sale_pks)
.annotate(channel_slug=F("channel__slug"))
)
channel_listings_map: Dict[int, Dict[str, SaleChannelListing]] = defaultdict(dict)
for channel_listing in channel_listings:
Expand Down
12 changes: 7 additions & 5 deletions saleor/graphql/discount/dataloaders.py
Expand Up @@ -33,11 +33,13 @@ def batch_load(self, keys):
for datetime in keys
}
pks = {s.pk for d, ss in sales_map.items() for s in ss}
collections = fetch_collections(pks)
channel_listings = fetch_sale_channel_listings(pks)
products = fetch_products(pks)
categories = fetch_categories(pks)
variants = fetch_variants(pks)
collections = fetch_collections(pks, self.database_connection_name)
channel_listings = fetch_sale_channel_listings(
pks, self.database_connection_name
)
products = fetch_products(pks, self.database_connection_name)
categories = fetch_categories(pks, self.database_connection_name)
variants = fetch_variants(pks, self.database_connection_name)

return [
[
Expand Down

0 comments on commit 3fc7011

Please sign in to comment.