diff --git a/lib/pundit/policy_finder.rb b/lib/pundit/policy_finder.rb index 6072f369..615fa7ba 100644 --- a/lib/pundit/policy_finder.rb +++ b/lib/pundit/policy_finder.rb @@ -36,7 +36,7 @@ def scope # policy.update? #=> false # def policy - klass = find + klass = find(object) klass = klass.constantize if klass.is_a?(String) klass rescue NameError @@ -48,7 +48,7 @@ def policy # def scope! raise NotDefinedError, "unable to find policy scope of nil" if object.nil? - scope or raise NotDefinedError, "unable to find scope `#{find}::Scope` for `#{object.inspect}`" + scope or raise NotDefinedError, "unable to find scope `#{find(object)}::Scope` for `#{object.inspect}`" end # @return [Class] policy class with query methods @@ -56,7 +56,7 @@ def scope! # def policy! raise NotDefinedError, "unable to find policy of nil" if object.nil? - policy or raise NotDefinedError, "unable to find policy `#{find}` for `#{object.inspect}`" + policy or raise NotDefinedError, "unable to find policy `#{find(object)}` for `#{object.inspect}`" end # @return [String] the name of the key this object would have in a params hash @@ -73,19 +73,20 @@ def param_key private - def find - if object.nil? + def find(subject) + if subject.nil? nil - elsif object.respond_to?(:policy_class) - object.policy_class - elsif object.class.respond_to?(:policy_class) - object.class.policy_class + elsif subject.is_a?(Array) + modules = subject.dup + last = modules.pop + context = modules.map { |x| find_class_name(x) }.join("::") + [context, find(last)].join("::") + elsif subject.respond_to?(:policy_class) + subject.policy_class + elsif subject.class.respond_to?(:policy_class) + subject.class.policy_class else - klass = if object.is_a?(Array) - object.map { |x| find_class_name(x) }.join("::") - else - find_class_name(object) - end + klass = find_class_name(subject) "#{klass}#{SUFFIX}" end end diff --git a/spec/pundit_spec.rb b/spec/pundit_spec.rb index ad63d906..aa887565 100644 --- a/spec/pundit_spec.rb +++ b/spec/pundit_spec.rb @@ -134,6 +134,13 @@ expect(policy.post).to eq [:project, post] end + it "returns an instantiated policy given an array of a symbol and a model instance with policy_class override" do + policy = Pundit.policy(user, [:project, customer_post]) + expect(policy.class).to eq Project::PostPolicy + expect(policy.user).to eq user + expect(policy.post).to eq [:project, customer_post] + end + it "returns an instantiated policy given an array of a symbol and an active model instance" do policy = Pundit.policy(user, [:project, comment]) expect(policy.class).to eq Project::CommentPolicy @@ -155,6 +162,13 @@ expect(policy.post).to eq [:project, Comment] end + it "returns an instantiated policy given an array of a symbol and a class with policy_class override" do + policy = Pundit.policy(user, [:project, Customer::Post]) + expect(policy.class).to eq Project::PostPolicy + expect(policy.user).to eq user + expect(policy.post).to eq [:project, Customer::Post] + end + it "returns correct policy class for an array of a multi-word symbols" do policy = Pundit.policy(user, [:project_one_two_three, :criteria_four_five_six]) expect(policy.class).to eq ProjectOneTwoThree::CriteriaFourFiveSixPolicy diff --git a/spec/spec_helper.rb b/spec/spec_helper.rb index c784c2f2..99608856 100644 --- a/spec/spec_helper.rb +++ b/spec/spec_helper.rb @@ -77,9 +77,13 @@ def model_name OpenStruct.new(param_key: "customer_post") end - def policy_class + def self.policy_class PostPolicy end + + def policy_class + self.class.policy_class + end end end