diff --git a/django_restql/mixins.py b/django_restql/mixins.py index ddfddfb..70c04f2 100644 --- a/django_restql/mixins.py +++ b/django_restql/mixins.py @@ -4,7 +4,7 @@ Serializer, ListSerializer, ValidationError ) -from django.db.models import Prefetch +from .settings import restql_settings from django.db.models.fields.related import( ManyToOneRel, ManyToManyRel ) @@ -18,12 +18,47 @@ ) -class DynamicFieldsMixin(object): - query_param_name = "query" +class RequestQueryParserMixin(object): + @staticmethod + def get_restql_query_param_name(): + DEFAULT_QUERY_PARAM_NAME = 'query' + query_param_name = getattr( + restql_settings, + "QUERY_PARAM_NAME", + DEFAULT_QUERY_PARAM_NAME + ) + return query_param_name + + @classmethod + def has_restql_query_param(cls, request): + query_param_name = cls.get_restql_query_param_name() + return query_param_name in request.query_params + + @classmethod + def get_raw_restql_query(cls, request): + query_param_name = cls.get_restql_query_param_name() + return request.query_params[query_param_name] + + @classmethod + def get_parsed_restql_query_from_req(cls, request): + raw_query = cls.get_raw_restql_query(request) + parser = Parser(raw_query) + try: + parsed_restql_query = parser.get_parsed() + return parsed_restql_query + except SyntaxError as e: + msg = ( + "QueryFormatError: " + + e.msg + " on " + + e.text + ) + raise ValidationError(msg) from None + +class DynamicFieldsMixin(RequestQueryParserMixin): def __init__(self, *args, **kwargs): # Don't pass 'query', 'fields' and 'exclude' kwargs to the superclass - self.query = kwargs.pop('query', None) # Parsed query + self.parsed_restql_query = kwargs.pop('query', None) self.allowed_fields = kwargs.pop('fields', None) self.excluded_fields = kwargs.pop('exclude', None) self.return_pk = kwargs.pop('return_pk', False) @@ -41,29 +76,6 @@ def to_representation(self, instance): return instance.pk return super().to_representation(instance) - @classmethod - def has_query_param(cls, request): - return cls.query_param_name in request.query_params - - @classmethod - def get_raw_query(cls, request): - return request.query_params[cls.query_param_name] - - @classmethod - def get_parsed_query_from_req(cls, request): - raw_query = cls.get_raw_query(request) - parser = Parser(raw_query) - try: - parsed_query = parser.get_parsed() - return parsed_query - except SyntaxError as e: - msg = ( - "QueryFormatError: " + - e.msg + " on " + - e.text - ) - raise ValidationError(msg) from None - def get_allowed_fields(self): fields = super().fields if self.allowed_fields is not None: @@ -122,9 +134,10 @@ def include_fields(self): # The format is {nested_field: [sub_fields ...] ...} allowed_nested_fields = {} - # The self.query["include"] contains a list of allowed fields + # The self.parsed_restql_query["include"] + # contains a list of allowed fields, # The format is [field, {nested_field: [sub_fields ...]} ...] - included_fields = self.query["include"] + included_fields = self.parsed_restql_query["include"] include_all_fields = False for field in included_fields: if field == "*": @@ -172,9 +185,10 @@ def exclude_fields(self): # The format is {nested_field: [sub_fields ...] ...} allowed_nested_fields = {} - # The self.query["include"] contains a list of expanded nested fields + # The self.parsed_restql_query["include"] + # contains a list of expanded nested fields # The format is [{nested_field: [sub_field]} ...] - nested_fields = self.query["include"] + nested_fields = self.parsed_restql_query["include"] for field in nested_fields: if field == "*": # Ignore this since it's not an actual field(it's just a flag) @@ -192,8 +206,9 @@ def exclude_fields(self): ) allowed_nested_fields.update(field) - # self.query["exclude"] is a list of names of excluded fields - excluded_fields = self.query["exclude"] + # self.parsed_restql_query["exclude"] + # is a list of names of excluded fields + excluded_fields = self.parsed_restql_query["exclude"] for field in excluded_fields: self.is_field_found(field, all_field_names, raise_error=True) all_fields.pop(field) @@ -208,7 +223,7 @@ def fields(self): is_not_a_request_to_process = ( request is None or request.method != "GET" or - not self.has_query_param(request) + not self.has_restql_query_param(request) ) if is_not_a_request_to_process: @@ -225,42 +240,46 @@ def fields(self): ) if is_top_retrieve_request or is_top_list_request: - if self.query is None: - # Use a query from the request - self.query = self.get_parsed_query_from_req(request) + if self.parsed_restql_query is None: + # Use a parsed query from the request + self.parsed_restql_query = \ + self.get_parsed_restql_query_from_req(request) elif isinstance(self.parent, ListSerializer): field_name = self.parent.field_name parent = self.parent.parent if hasattr(parent, "nested_fields"): parent_nested_fields = parent.nested_fields - self.query = parent_nested_fields.get(field_name, None) + self.parsed_restql_query = \ + parent_nested_fields.get(field_name, None) elif isinstance(self.parent, Serializer): field_name = self.field_name parent = self.parent if hasattr(parent, "nested_fields"): parent_nested_fields = parent.nested_fields - self.query = parent_nested_fields.get(field_name, None) + self.parsed_restql_query = \ + parent_nested_fields.get(field_name, None) else: # Unkown scenario # No filtering of fields return self.get_allowed_fields() - if self.query is None: + if self.parsed_restql_query is None: # No filtering on nested fields # Retrieve all nested fields return self.get_allowed_fields() - # NOTE: self.query["include"] not being empty is not a guarantee - # that the exclude operator(-) has not been used because the same - # self.query["include"] is used to store nested fields when the - # exclude operator(-) is used - if self.query["exclude"]: + # NOTE: self.parsed_restql_query["include"] not being empty + # is not a guarantee that the exclude operator(-) has not been + # used because the same self.parsed_restql_query["include"] + # is used to store nested fields when the exclude operator(-) is used + if self.parsed_restql_query["exclude"]: # Exclude fields from a query return self.exclude_fields() - elif self.query["include"]: - # Here we are sure that self.query["exclude"] is empty - # which means the exclude operator(-) is not used, so - # self.query["include"] contains only fields to include + elif self.parsed_restql_query["include"]: + # Here we are sure that self.parsed_restql_query["exclude"] + # is empty which means the exclude operator(-) is not used, + # so self.parsed_restql_query["include"] contains only fields + # to include return self.include_fields() else: # The query is empty i.e query={} @@ -268,20 +287,16 @@ def fields(self): return {} -class EagerLoadingMixin(object): +class EagerLoadingMixin(RequestQueryParserMixin): @property - def parsed_query(self): + def parsed_restql_query(self): """ Gets parsed query for use in eager loading. Defaults to the serializer parsed query assuming using django-restql DynamicsFieldMixin. """ - if hasattr(self, "get_serializer_class"): - serializer_class = self.get_serializer_class() - - if issubclass(serializer_class, DynamicFieldsMixin): - if serializer_class.has_query_param(self.request): - return serializer_class.get_parsed_query_from_req(self.request) + if self.has_restql_query_param(self.request): + return self.get_parsed_restql_query_from_req(self.request) # Else include all fields query = { @@ -303,13 +318,13 @@ def get_prefetch_related_mapping(self): return {} @classmethod - def get_dict_parsed_query(cls, parsed_query): + def get_dict_parsed_restql_query(cls, parsed_restql_query): """ Returns the parsed query as a dict. """ keys = {} - include = parsed_query.get("include", []) - exclude = parsed_query.get("exclude", []) + include = parsed_restql_query.get("include", []) + exclude = parsed_restql_query.get("exclude", []) for item in include: if isinstance(item, str): @@ -317,7 +332,7 @@ def get_dict_parsed_query(cls, parsed_query): elif isinstance(item, dict): for key, nested_items in item.items(): key_base = key - nested_keys = cls.get_dict_parsed_query(nested_items) + nested_keys = cls.get_dict_parsed_restql_query(nested_items) keys[key_base] = nested_keys for item in exclude: @@ -326,12 +341,12 @@ def get_dict_parsed_query(cls, parsed_query): elif isinstance(item, dict): for key, nested_items in item.items(): key_base = key - nested_keys = cls.get_dict_parsed_query(nested_items) + nested_keys = cls.get_dict_parsed_restql_query(nested_items) keys[key_base] = nested_keys return keys @staticmethod - def get_related_fields(related_fields_mapping, dict_parsed_query): + def get_related_fields(related_fields_mapping, dict_parsed_restql_query): """ Returns only whitelisted related fields from a query to be used on `select_related` and `prefetch_related` @@ -342,7 +357,7 @@ def get_related_fields(related_fields_mapping, dict_parsed_query): if isinstance(related_field, str): related_field = [related_field] - query_node = dict_parsed_query + query_node = dict_parsed_restql_query for field in fields: if isinstance(query_node, dict): if field in query_node: @@ -366,7 +381,7 @@ def apply_eager_loading(self, queryset): Applies appropriate select_related and prefetch_related calls on a queryset """ - query = self.get_dict_parsed_query(self.parsed_query) + query = self.get_dict_parsed_restql_query(self.parsed_restql_query) select_mapping = self.get_select_related_mapping() prefetch_mapping = self.get_prefetch_related_mapping() @@ -391,7 +406,7 @@ def get_queryset(self): queryset = super().get_queryset() queryset = self.get_eager_queryset(queryset) return queryset - + class NestedCreateMixin(object): """ Create Mixin """ diff --git a/django_restql/settings.py b/django_restql/settings.py new file mode 100644 index 0000000..b836466 --- /dev/null +++ b/django_restql/settings.py @@ -0,0 +1,109 @@ +""" +Settings for Django RESTQL are all namespaced in the RESTQL setting. +For example your project's `settings.py` file might look like this: +RESTQL = { + 'QUERY_PARAM_NAME': 'query' +} +This module provides the `restql_settings` object, that is used to access +Django RESTQL settings, checking for user settings first, then falling +back to the defaults. +""" +from django.conf import settings +from django.test.signals import setting_changed +from django.utils.module_loading import import_string + + +DEFAULTS = { + 'QUERY_PARAM_NAME': 'query' +} + + +# List of settings that may be in string import notation. +IMPORT_STRINGS = [] + + +def perform_import(val, setting_name): + """ + If the given setting is a string import notation, + then perform the necessary import or imports. + """ + if val is None: + return None + elif isinstance(val, str): + return import_from_string(val, setting_name) + elif isinstance(val, (list, tuple)): + return [import_from_string(item, setting_name) for item in val] + return val + + +def import_from_string(val, setting_name): + """ + Attempt to import a class from a string representation. + """ + try: + return import_string(val) + except ImportError as e: + msg = ( + "Could not import '%s' for RESTQL setting '%s'. %s: %s." + ) % (val, setting_name, e.__class__.__name__, e) + raise ImportError(msg) + + +class RESTQLSettings: + """ + A settings object, that allows RESTQL settings to be accessed as properties. + For example: + from django_restql.settings import restql_settings + print(restql_settings.QUERY_PARAM_NAME) + Any setting with string import paths will be automatically resolved + and return the class, rather than the string literal. + """ + def __init__(self, user_settings=None, defaults=None, import_strings=None): + self.defaults = defaults or DEFAULTS + self.import_strings = import_strings or IMPORT_STRINGS + self._cached_attrs = set() + + @property + def user_settings(self): + if not hasattr(self, '_user_settings'): + self._user_settings = getattr(settings, 'RESTQL', {}) + return self._user_settings + + def __getattr__(self, attr): + if attr not in self.defaults: + raise AttributeError("Invalid RESTQL setting: '%s'" % attr) + + try: + # Check if present in user settings + val = self.user_settings[attr] + except KeyError: + # Fall back to defaults + val = self.defaults[attr] + + # Coerce import strings into classes + if attr in self.import_strings: + val = perform_import(val, attr) + + # Cache the result + self._cached_attrs.add(attr) + setattr(self, attr, val) + return val + + def reload(self): + for attr in self._cached_attrs: + delattr(self, attr) + self._cached_attrs.clear() + if hasattr(self, '_user_settings'): + delattr(self, '_user_settings') + + +restql_settings = RESTQLSettings(None, DEFAULTS, IMPORT_STRINGS) + + +def reload_restql_settings(*args, **kwargs): + setting = kwargs['setting'] + if setting == 'RESTQL': + restql_settings.reload() + + +setting_changed.connect(reload_restql_settings) \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 61c2ef9..304a7dc 100644 --- a/docs/index.md +++ b/docs/index.md @@ -614,18 +614,18 @@ When prefetching with a `to_attr`, ensure that there are no collisions. Django d When prefetching *and* calling `select_related` on a field, Django may error, since the ORM does allow prefetching a selectable field, but not both at the same time. ### Changing `query` parameter name -If you don't want to use the name `query` as your parameter, you can inherit `DynamicFieldsMixin` and change it as shown below - +If you don't want to use the name `query` as your parameter, you can change it with`QUERY_PARAM_NAME` on settings file e.g ```py -from django_restql.mixins import DynamicFieldsMixin -class MyDynamicFieldMixin(DynamicFieldsMixin): - query_param_name = "your_favourite_name" +RESTQL = { + 'QUERY_PARAM_NAME' = "your_favourite_name" +} ``` - - Now you can use this Mixin on your serializer and use the name `your_favourite_name` as your parameter. E.g + Now you can use the name `your_favourite_name` as your query parameter. E.g `GET /users/?your_favourite_name={id, username}` +**Note:** Configuration for **django-restql** is all namespaced inside a single Django setting named `RESTQL`. + ## Mutating Data **django-restql** got your back on creating and updating nested data too, it has two components for mutating nested data, `NestedModelSerializer` and `NestedField`. A serializer `NestedModelSerializer` has `update` and `create` logics for nested fields on the other hand `NestedField` is used to validate data before dispatching update or create.