diff --git a/lib/pundit.rb b/lib/pundit.rb index 5eab334f..e234ea92 100644 --- a/lib/pundit.rb +++ b/lib/pundit.rb @@ -8,6 +8,7 @@ require "active_support/core_ext/module/introspection" require "active_support/dependencies/autoload" require "pundit/authorization" +require "pundit/context" # @api private # To avoid name clashes with common Error naming when mixing in Pundit, @@ -64,104 +65,29 @@ def self.included(base) end class << self - # Retrieves the policy for the given record, initializing it with the - # record and user and finally throwing an error if the user is not - # authorized to perform the given action. - # - # @param user [Object] the user that initiated the action - # @param possibly_namespaced_record [Object, Array] the object we're checking permissions of - # @param query [Symbol, String] the predicate method to check on the policy (e.g. `:show?`) - # @param policy_class [Class] the policy class we want to force use of - # @param cache [#[], #[]=] a Hash-like object to cache the found policy instance in - # @raise [NotAuthorizedError] if the given query method returned false - # @return [Object] Always returns the passed object record - def authorize(user, possibly_namespaced_record, query, policy_class: nil, cache: {}) - record = pundit_model(possibly_namespaced_record) - policy = if policy_class - policy_class.new(user, record) - else - cache[possibly_namespaced_record] ||= policy!(user, possibly_namespaced_record) - end - - raise NotAuthorizedError, query: query, record: record, policy: policy unless policy.public_send(query) - - record + # @see [Pundit::Context#authorize] + def authorize(user, record, query, policy_class: nil, cache: {}) + Context.new(user: user, policy_cache: cache).authorize(record, query: query, policy_class: policy_class) end - # Retrieves the policy scope for the given record. - # - # @see https://github.com/varvet/pundit#scopes - # @param user [Object] the user that initiated the action - # @param scope [Object] the object we're retrieving the policy scope for - # @raise [InvalidConstructorError] if the policy constructor called incorrectly - # @return [Scope{#resolve}, nil] instance of scope class which can resolve to a scope - def policy_scope(user, scope) - policy_scope_class = PolicyFinder.new(scope).scope - return unless policy_scope_class - - begin - policy_scope = policy_scope_class.new(user, pundit_model(scope)) - rescue ArgumentError - raise InvalidConstructorError, "Invalid #<#{policy_scope_class}> constructor is called" - end - - policy_scope.resolve + # @see [Pundit::Context#policy_scope] + def policy_scope(user, *args, **kwargs, &block) + Context.new(user: user).policy_scope(*args, **kwargs, &block) end - # Retrieves the policy scope for the given record. - # - # @see https://github.com/varvet/pundit#scopes - # @param user [Object] the user that initiated the action - # @param scope [Object] the object we're retrieving the policy scope for - # @raise [NotDefinedError] if the policy scope cannot be found - # @raise [InvalidConstructorError] if the policy constructor called incorrectly - # @return [Scope{#resolve}] instance of scope class which can resolve to a scope - def policy_scope!(user, scope) - policy_scope_class = PolicyFinder.new(scope).scope! - return unless policy_scope_class - - begin - policy_scope = policy_scope_class.new(user, pundit_model(scope)) - rescue ArgumentError - raise InvalidConstructorError, "Invalid #<#{policy_scope_class}> constructor is called" - end - - policy_scope.resolve + # @see [Pundit::Context#policy_scope!] + def policy_scope!(user, *args, **kwargs, &block) + Context.new(user: user).policy_scope!(*args, **kwargs, &block) end - # Retrieves the policy for the given record. - # - # @see https://github.com/varvet/pundit#policies - # @param user [Object] the user that initiated the action - # @param record [Object] the object we're retrieving the policy for - # @raise [InvalidConstructorError] if the policy constructor called incorrectly - # @return [Object, nil] instance of policy class with query methods - def policy(user, record) - policy = PolicyFinder.new(record).policy - policy&.new(user, pundit_model(record)) - rescue ArgumentError - raise InvalidConstructorError, "Invalid #<#{policy}> constructor is called" + # @see [Pundit::Context#policy] + def policy(user, *args, **kwargs, &block) + Context.new(user: user).policy(*args, **kwargs, &block) end - # Retrieves the policy for the given record. - # - # @see https://github.com/varvet/pundit#policies - # @param user [Object] the user that initiated the action - # @param record [Object] the object we're retrieving the policy for - # @raise [NotDefinedError] if the policy cannot be found - # @raise [InvalidConstructorError] if the policy constructor called incorrectly - # @return [Object] instance of policy class with query methods - def policy!(user, record) - policy = PolicyFinder.new(record).policy! - policy.new(user, pundit_model(record)) - rescue ArgumentError - raise InvalidConstructorError, "Invalid #<#{policy}> constructor is called" - end - - private - - def pundit_model(record) - record.is_a?(Array) ? record.last : record + # @see [Pundit::Context#policy!] + def policy!(user, *args, **kwargs, &block) + Context.new(user: user).policy!(*args, **kwargs, &block) end end diff --git a/lib/pundit/authorization.rb b/lib/pundit/authorization.rb index 1231f2a7..b47cb20c 100644 --- a/lib/pundit/authorization.rb +++ b/lib/pundit/authorization.rb @@ -15,6 +15,14 @@ module Authorization protected + # @return [Pundit::Context] a new instance of {Pundit::Context} with the current user + def pundit + @pundit ||= Pundit::Context.new( + user: pundit_user, + policy_cache: policies + ) + end + # @return [Boolean] whether authorization has been performed, i.e. whether # one {#authorize} or {#skip_authorization} has been called def pundit_policy_authorized? @@ -64,7 +72,7 @@ def authorize(record, query = nil, policy_class: nil) @_pundit_policy_authorized = true - Pundit.authorize(pundit_user, record, query, policy_class: policy_class, cache: policies) + pundit.authorize(record, query: query, policy_class: policy_class) end # Allow this action not to perform authorization. @@ -100,7 +108,7 @@ def policy_scope(scope, policy_scope_class: nil) # @param record [Object] the object we're retrieving the policy for # @return [Object, nil] instance of policy class with query methods def policy(record) - policies[record] ||= Pundit.policy!(pundit_user, record) + policies[record] ||= pundit.policy!(record) end # Retrieves a set of permitted attributes from the policy by instantiating @@ -162,7 +170,7 @@ def pundit_user private def pundit_policy_scope(scope) - policy_scopes[scope] ||= Pundit.policy_scope!(pundit_user, scope) + policy_scopes[scope] ||= pundit.policy_scope!(scope) end end end diff --git a/lib/pundit/context.rb b/lib/pundit/context.rb new file mode 100644 index 00000000..9e8d6078 --- /dev/null +++ b/lib/pundit/context.rb @@ -0,0 +1,118 @@ +# frozen_string_literal: true + +module Pundit + class Context + def initialize(user:, policy_cache: {}) + @user = user + @policy_cache = policy_cache + end + + attr_reader :user + + # @api private + attr_reader :policy_cache + + # Retrieves the policy for the given record, initializing it with the + # record and user and finally throwing an error if the user is not + # authorized to perform the given action. + # + # @param user [Object] the user that initiated the action + # @param possibly_namespaced_record [Object, Array] the object we're checking permissions of + # @param query [Symbol, String] the predicate method to check on the policy (e.g. `:show?`) + # @param policy_class [Class] the policy class we want to force use of + # @raise [NotAuthorizedError] if the given query method returned false + # @return [Object] Always returns the passed object record + def authorize(possibly_namespaced_record, query:, policy_class:) + record = pundit_model(possibly_namespaced_record) + policy = if policy_class + policy_class.new(user, record) + else + policy_cache[possibly_namespaced_record] ||= policy!(possibly_namespaced_record) + end + + raise NotAuthorizedError, query: query, record: record, policy: policy unless policy.public_send(query) + + record + end + + # Retrieves the policy scope for the given record. + # + # @see https://github.com/varvet/pundit#scopes + # @param user [Object] the user that initiated the action + # @param scope [Object] the object we're retrieving the policy scope for + # @raise [InvalidConstructorError] if the policy constructor called incorrectly + # @return [Scope{#resolve}, nil] instance of scope class which can resolve to a scope + def policy_scope(scope) + policy_scope_class = policy_finder(scope).scope + return unless policy_scope_class + + begin + policy_scope = policy_scope_class.new(user, pundit_model(scope)) + rescue ArgumentError + raise InvalidConstructorError, "Invalid #<#{policy_scope_class}> constructor is called" + end + + policy_scope.resolve + end + + # Retrieves the policy scope for the given record. + # + # @see https://github.com/varvet/pundit#scopes + # @param user [Object] the user that initiated the action + # @param scope [Object] the object we're retrieving the policy scope for + # @raise [NotDefinedError] if the policy scope cannot be found + # @raise [InvalidConstructorError] if the policy constructor called incorrectly + # @return [Scope{#resolve}] instance of scope class which can resolve to a scope + def policy_scope!(scope) + policy_scope_class = policy_finder(scope).scope! + return unless policy_scope_class + + begin + policy_scope = policy_scope_class.new(user, pundit_model(scope)) + rescue ArgumentError + raise InvalidConstructorError, "Invalid #<#{policy_scope_class}> constructor is called" + end + + policy_scope.resolve + end + + # Retrieves the policy for the given record. + # + # @see https://github.com/varvet/pundit#policies + # @param user [Object] the user that initiated the action + # @param record [Object] the object we're retrieving the policy for + # @raise [InvalidConstructorError] if the policy constructor called incorrectly + # @return [Object, nil] instance of policy class with query methods + def policy(record) + policy = policy_finder(record).policy + policy&.new(user, pundit_model(record)) + rescue ArgumentError + raise InvalidConstructorError, "Invalid #<#{policy}> constructor is called" + end + + # Retrieves the policy for the given record. + # + # @see https://github.com/varvet/pundit#policies + # @param user [Object] the user that initiated the action + # @param record [Object] the object we're retrieving the policy for + # @raise [NotDefinedError] if the policy cannot be found + # @raise [InvalidConstructorError] if the policy constructor called incorrectly + # @return [Object] instance of policy class with query methods + def policy!(record) + policy = policy_finder(record).policy! + policy.new(user, pundit_model(record)) + rescue ArgumentError + raise InvalidConstructorError, "Invalid #<#{policy}> constructor is called" + end + + private + + def policy_finder(record) + PolicyFinder.new(record) + end + + def pundit_model(record) + record.is_a?(Array) ? record.last : record + end + end +end