Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add filtering attributes by collection #3508

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ All notable, unreleased changes to this project will be documented in this file.
- Change "Feature on Homepage" switch behavior - #3481 by @dominik-zeglen
- Expand payment section in order view - #3502 by @dominik-zeglen
- Fixed migrations for default currency - #3235 by @bykof
- Filter attributes by collection in API - #3508 by @maarcingebala
40 changes: 27 additions & 13 deletions saleor/graphql/product/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,38 @@
ATTRIBUTES_SEARCH_FIELDS = ('name', 'slug')


def resolve_attributes(info, category_id, query):
def _filter_attributes_by_product_types(attribute_qs, product_qs):
product_types = set(product_qs.values_list('product_type_id', flat=True))
return attribute_qs.filter(
Q(product_type__in=product_types)
| Q(product_variant_type__in=product_types))


def resolve_attributes(info, category_id=None, collection_id=None, query=None):
qs = models.Attribute.objects.all()
qs = filter_by_query_param(qs, query, ATTRIBUTES_SEARCH_FIELDS)

if category_id:
# Get attributes that are used with product types
# within the given category.
# Filter attributes by product types belonging to the given category.
category = graphene.Node.get_node_from_global_id(
info, category_id, Category)
if category is None:
return qs.none()
tree = category.get_descendants(include_self=True)
product_types = {
obj[0]
for obj in models.Product.objects.filter(
category__in=tree).values_list('product_type_id')}
qs = qs.filter(
Q(product_type__in=product_types)
| Q(product_variant_type__in=product_types))
if category:
tree = category.get_descendants(include_self=True)
product_qs = models.Product.objects.filter(category__in=tree)
qs = _filter_attributes_by_product_types(qs, product_qs)
else:
qs = qs.none()

if collection_id:
# Filter attributes by product types belonging to the given collection.
collection = graphene.Node.get_node_from_global_id(
info, collection_id, Collection)
if collection:
product_qs = collection.products.all()
qs = _filter_attributes_by_product_types(qs, product_qs)
else:
qs = qs.none()

qs = qs.order_by('name')
qs = qs.distinct()
return gql_optimizer.query(qs, info)
Expand Down
26 changes: 20 additions & 6 deletions saleor/graphql/product/schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from textwrap import dedent

import graphene
from graphql_jwt.decorators import permission_required

Expand Down Expand Up @@ -28,9 +30,16 @@
class ProductQueries(graphene.ObjectType):
attributes = PrefetchingConnectionField(
Attribute,
description='List of the shop\'s attributes.',
query=graphene.String(description=DESCRIPTIONS['attributes']),
in_category=graphene.Argument(graphene.ID),
description='List of the shop\'s attributes.')
in_category=graphene.Argument(
graphene.ID, description=dedent(
'''Return attributes for products belonging to the given
category.''')),
in_collection=graphene.Argument(
graphene.ID, description=dedent(
'''Return attributes for products belonging to the given
collection.''')),)
categories = PrefetchingConnectionField(
Category, query=graphene.String(
description=DESCRIPTIONS['category']),
Expand Down Expand Up @@ -58,9 +67,12 @@ class ProductQueries(graphene.ObjectType):
collections=graphene.List(
graphene.ID, description='Filter products by collections.'),
price_lte=graphene.Float(
description='Filter by price less than or equal to the given value.'),
description=dedent(
'''Filter by price less than or equal to the given value.''')),
price_gte=graphene.Float(
description='Filter by price greater than or equal to the given value.'),
description=dedent(
'''
Filter by price greater than or equal to the given value.''')),
sort_by=graphene.Argument(
ProductOrder, description='Sort products.'),
stock_availability=graphene.Argument(
Expand All @@ -85,8 +97,10 @@ class ProductQueries(graphene.ObjectType):
ReportingPeriod, required=True, description='Span of time.'),
description='List of top selling products.')

def resolve_attributes(self, info, in_category=None, query=None, **kwargs):
return resolve_attributes(info, in_category, query)
def resolve_attributes(
self, info, in_category=None, in_collection=None, query=None,
**kwargs):
return resolve_attributes(info, in_category, in_collection, query)

def resolve_categories(self, info, level=None, query=None, **kwargs):
return resolve_categories(info, level=level, query=query)
Expand Down
2 changes: 1 addition & 1 deletion saleor/graphql/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1576,7 +1576,7 @@ type Query {
shop: Shop
shippingZone(id: ID!): ShippingZone
shippingZones(before: String, after: String, first: Int, last: Int): ShippingZoneCountableConnection
attributes(query: String, inCategory: ID, before: String, after: String, first: Int, last: Int): AttributeCountableConnection
attributes(query: String, inCategory: ID, inCollection: ID, before: String, after: String, first: Int, last: Int): AttributeCountableConnection
categories(query: String, level: Int, before: String, after: String, first: Int, last: Int): CategoryCountableConnection
category(id: ID!): Category
collection(id: ID!): Collection
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export interface CollectionUpdateWithHomepage_homepageCollectionUpdate_shop {

export interface CollectionUpdateWithHomepage_homepageCollectionUpdate {
__typename: "HomepageCollectionUpdate";
errors: (CollectionUpdateWithHomepage_homepageCollectionUpdate_errors | null)[] | null;
errors: CollectionUpdateWithHomepage_homepageCollectionUpdate_errors[] | null;
shop: CollectionUpdateWithHomepage_homepageCollectionUpdate_shop | null;
}

Expand Down Expand Up @@ -54,7 +54,7 @@ export interface CollectionUpdateWithHomepage_collectionUpdate_collection {

export interface CollectionUpdateWithHomepage_collectionUpdate {
__typename: "CollectionUpdate";
errors: (CollectionUpdateWithHomepage_collectionUpdate_errors | null)[] | null;
errors: CollectionUpdateWithHomepage_collectionUpdate_errors[] | null;
collection: CollectionUpdateWithHomepage_collectionUpdate_collection | null;
}

Expand Down
34 changes: 33 additions & 1 deletion tests/api/test_attributes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import graphene
import pytest
from django.db.models import Q
from django.template.defaultfilters import slugify
from tests.api.utils import get_graphql_content

from saleor.graphql.product.types import (
AttributeTypeEnum, AttributeValueType, resolve_attribute_value_type)
from saleor.graphql.product.utils import attributes_to_hstore
from saleor.product.models import Attribute, AttributeValue, Category
from tests.api.utils import get_graphql_content


def test_attributes_to_hstore(product, color_attribute):
Expand Down Expand Up @@ -87,6 +88,37 @@ def test_attributes_in_category_query(user_api_client, product):
assert len(attributes_data) == Attribute.objects.count()


def test_attributes_in_collection_query(user_api_client, sale):
product_types = set(
sale.products.all().values_list('product_type_id', flat=True))
expected_attrs = Attribute.objects.filter(
Q(product_type__in=product_types)
| Q(product_variant_type__in=product_types))

query = """
query {
attributes(inCollection: "%(collection_id)s", first: 20) {
edges {
node {
id
name
slug
values {
id
name
slug
}
}
}
}
}
""" % {'collection_id': graphene.Node.to_global_id('Collection', sale.id)}
response = user_api_client.post_graphql(query)
content = get_graphql_content(response)
attributes_data = content['data']['attributes']['edges']
assert len(attributes_data) == len(expected_attrs)


CREATE_ATTRIBUTES_QUERY = """
mutation createAttribute(
$name: String!, $values: [AttributeValueCreateInput],
Expand Down