Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed: jsonpath lib used as ClaimName #453

Merged
merged 3 commits into from May 29, 2019
Merged
Changes from all commits
Commits
File filter...
Filter file types
Jump to…
Jump to file or symbol
Failed to load files and symbols.

Always

Just for now

@@ -78,7 +78,7 @@ class JwtAuthRule(val settings: JwtAuthRule.Settings)
case Left(_) =>
Task.now(Rejected)
case Right((user, groups)) =>
logger.debug(s"JWT resolved user for claim ${settings.jwt.userClaim}: $user, and groups for claim ${settings.jwt.groupsClaim}: ${groups}")
logger.debug(s"JWT resolved user for claim ${settings.jwt.userClaim}: $user, and groups for claim ${settings.jwt.groupsClaim}: $groups")
val claimProcessingResult = for {
newBlockContext <- handleUserClaimSearchResult(blockContext, user)
finalBlockContext <- handleGroupsClaimSearchResult(newBlockContext, groups)
@@ -35,6 +35,7 @@ import tech.beshu.ror.acl.request.RequestContextOps._
import tech.beshu.ror.acl.show.logs._
import tech.beshu.ror.acl.utils.ClaimsOps.ClaimSearchResult.{Found, NotFound}
import tech.beshu.ror.acl.utils.ClaimsOps._
import tech.beshu.ror.com.jayway.jsonpath.JsonPath

import scala.collection.SortedSet
import scala.util.Try
@@ -139,6 +140,6 @@ object RorKbnAuthRule {

final case class Settings(rorKbn: RorKbnDef, groups: Set[Group])

private val userClaimName = ClaimName(NonEmptyString.unsafeFrom("user"))
private val groupsClaimName = ClaimName(NonEmptyString.unsafeFrom("groups"))
private val userClaimName = ClaimName(JsonPath.compile("user"))
private val groupsClaimName = ClaimName(JsonPath.compile("groups"))
}
@@ -29,6 +29,7 @@ import monix.eval.Task
import org.apache.logging.log4j.scala.Logging
import tech.beshu.ror.acl.header.ToHeaderValue
import tech.beshu.ror.Constants
import tech.beshu.ror.com.jayway.jsonpath.JsonPath

import scala.util.Try

@@ -229,7 +230,17 @@ object domain {
}


final case class ClaimName(value: NonEmptyString)
final case class ClaimName(name: JsonPath) {

override def equals(other: Any): Boolean = {
other match {
case that: ClaimName => that.name.getPath.equals(this.name.getPath)
case _ => false
}
}

override def hashCode: Int = name.getPath.hashCode
}

final case class JwtToken(value: NonEmptyString)

@@ -29,23 +29,26 @@ import eu.timepit.refined.numeric.Positive
import eu.timepit.refined.refineV
import eu.timepit.refined.types.string.NonEmptyString
import io.circe.Decoder
import org.apache.logging.log4j.scala.Logging
import tech.beshu.ror.acl.domain.{Address, Group, Header, User}
import tech.beshu.ror.acl.blocks.Value
import tech.beshu.ror.acl.blocks.Value.ConvertError
import tech.beshu.ror.acl.factory.CoreFactory.AclCreationError.Reason.Message
import tech.beshu.ror.acl.factory.CoreFactory.AclCreationError.ValueLevelCreationError
import tech.beshu.ror.acl.factory.CoreFactory.AclCreationError.{DefinitionsLevelCreationError, ValueLevelCreationError}
import tech.beshu.ror.acl.factory.HttpClientsFactory
import tech.beshu.ror.acl.factory.decoders.definitions.ExternalAuthorizationServicesDecoder.logger
import tech.beshu.ror.acl.orders._
import tech.beshu.ror.acl.refined._
import tech.beshu.ror.acl.utils.CirceOps._
import tech.beshu.ror.acl.utils.ScalaOps._
import tech.beshu.ror.acl.utils.SyncDecoderCreator
import tech.beshu.ror.com.jayway.jsonpath.JsonPath

import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.{Failure, Success, Try}

object common {
object common extends Logging {

implicit val nonEmptyStringDecoder: Decoder[NonEmptyString] =
Decoder
@@ -222,6 +225,21 @@ object common {
)
}

implicit val jsonPathDecoder: Decoder[JsonPath] =
SyncDecoderCreator
.from(Decoder.decodeString)
.emapE[JsonPath] { jsonPathStr =>
Try(JsonPath.compile(jsonPathStr))
.toEither
.left
.map { ex =>
logger.error("JSON path compilation failed", ex)
DefinitionsLevelCreationError(Message(s"Cannot compile '$jsonPathStr' to JSON path"))
}
}
.decoder


private lazy val finiteDurationStringDecoder: Decoder[FiniteDuration] =
DecoderHelpers
.decodeStringLike
@@ -75,20 +75,6 @@ object ExternalAuthorizationServicesDecoder extends Logging {
}
.decoder

private implicit val jsonPathDecoder: Decoder[JsonPath] =
SyncDecoderCreator
.from(Decoder.decodeString)
.emapE[JsonPath] { jsonPathStr =>
Try(JsonPath.compile(jsonPathStr))
.toEither
.left
.map { ex =>
logger.error("JSON path compilation failed", ex)
DefinitionsLevelCreationError(Message(s"Cannot compile '$jsonPathStr' to JSON path"))
}
}
.decoder

private implicit val headerSetDecoder: Decoder[Set[Header]] =
decoderTupleListDecoder.map(_.map(Header.apply).toSet)

@@ -42,7 +42,7 @@ object JwtDefinitionsDecoder {

implicit val jwtDefNameDecoder: Decoder[Name] = DecoderHelpers.decodeStringLikeNonEmpty.map(Name.apply)

private implicit val claimDecoder: Decoder[ClaimName] = DecoderHelpers.decodeStringLikeNonEmpty.map(ClaimName.apply)
private implicit val claimDecoder: Decoder[ClaimName] = jsonPathDecoder.map(ClaimName.apply)

private def jwtDefDecoder(implicit httpClientFactory: HttpClientsFactory,
resolver: StaticVariablesResolver): Decoder[JwtDef] = {
@@ -16,18 +16,18 @@
*/
package tech.beshu.ror.acl.utils

import java.util

import eu.timepit.refined.types.string.NonEmptyString
import io.jsonwebtoken.Claims
import org.apache.logging.log4j.scala.Logging
import tech.beshu.ror.acl.domain.{ClaimName, Group, Header, User}
import tech.beshu.ror.acl.utils.ClaimsOps.ClaimSearchResult
import tech.beshu.ror.acl.utils.ClaimsOps.ClaimSearchResult._

import scala.collection.JavaConverters._
import scala.language.implicitConversions
import scala.language.{implicitConversions, postfixOps}
import scala.util.Try

class ClaimsOps(val claims: Claims) extends AnyVal {
class ClaimsOps(val claims: Claims) extends Logging {

def headerNameClaim(name: Header.Name): ClaimSearchResult[Header] = {
Option(claims.get(name.value.value, classOf[String]))
@@ -37,39 +37,43 @@ class ClaimsOps(val claims: Claims) extends AnyVal {
}
}

def userIdClaim(name: ClaimName): ClaimSearchResult[User.Id] = {
Option(claims.get(name.value.value, classOf[String])) match {
case Some(id) => Found(User.Id(id))
case None => NotFound
}
def userIdClaim(claimName: ClaimName): ClaimSearchResult[User.Id] = {
Try(claimName.name.read[Any](claims))
.map {
case value: String => Found(User.Id(value))
case _ => NotFound
}
.fold(
ex => {
logger.debug("JsonPath reading exception", ex)
NotFound
},
identity
)
}

// todo: use json path (with jackson? or maybe we can convert java map to json?)
def groupsClaim(name: ClaimName): ClaimSearchResult[Set[Group]] = {
val result = name.value.value.split("[.]").toList match {
case Nil | _ :: Nil =>
Option(claims.get(name.value.value, classOf[Object]))
case path :: restPaths =>
restPaths.foldLeft(Option(claims.get(path, classOf[Object]))) {
case (None, _) => None
case (Some(value), currentPath) =>
value match {
case map: util.Map[String, Object] =>
Option(map.get(currentPath))
case _ =>
Some(value)
}
}
}
result match {
case Some(value: String) =>
Found(toGroup(value).map(Set(_)).getOrElse(Set.empty))
case Some(values) if values.isInstanceOf[util.Collection[String]] =>
val collection = values.asInstanceOf[util.Collection[String]]
Found(collection.asScala.toList.flatMap(toGroup).toSet)
case _ =>
NotFound
}
def groupsClaim(claimName: ClaimName): ClaimSearchResult[Set[Group]] = {
Try(claimName.name.read[Any](claims))
.map {
case value: String =>
Found((value :: Nil).flatMap(toGroup).toSet)
case collection: java.util.Collection[_] =>
Found {
collection.asScala
.collect { case value: String => value }
.flatMap(toGroup)
.toSet
}
case _ =>
NotFound
}
.fold(
ex => {
logger.debug("JsonPath reading exception", ex)
NotFound
},
identity
)
}

private def toGroup(value: String) = {
@@ -32,6 +32,7 @@ import tech.beshu.ror.acl.blocks.rules.JwtAuthRule
import tech.beshu.ror.acl.blocks.rules.Rule.RuleResult.{Fulfilled, Rejected}
import tech.beshu.ror.acl.blocks.{BlockContext, RequestContextInitiatedBlockContext}
import tech.beshu.ror.acl.domain._
import tech.beshu.ror.com.jayway.jsonpath.JsonPath
import tech.beshu.ror.mocks.MockRequestContext
import tech.beshu.ror.utils.TestsUtils
import tech.beshu.ror.utils.TestsUtils._
@@ -106,7 +107,7 @@ class JwtAuthRuleTests
JwtDef.Name("test".nonempty),
AuthorizationTokenDef(Header.Name.authorization, "Bearer "),
SignatureCheckMethod.Hmac(key.getEncoded),
userClaim = Some(ClaimName("userId".nonempty)),
userClaim = Some(ClaimName(JsonPath.compile("userId"))),
groupsClaim = None
),
tokenHeader = Header(
@@ -126,8 +127,8 @@ class JwtAuthRuleTests
JwtDef.Name("test".nonempty),
AuthorizationTokenDef(Header.Name.authorization, "Bearer "),
SignatureCheckMethod.Hmac(key.getEncoded),
userClaim = Some(ClaimName("userId".nonempty)),
groupsClaim = Some(ClaimName("groups".nonempty))
userClaim = Some(ClaimName(JsonPath.compile("userId"))),
groupsClaim = Some(ClaimName(JsonPath.compile("groups")))
),
tokenHeader = Header(
Header.Name.authorization,
@@ -146,15 +147,42 @@ class JwtAuthRuleTests
)(blockContext)
}
}
"groups claim name is defined as http address and groups are passed in JWT token claim" in {
val key: Key = Keys.secretKeyFor(SignatureAlgorithm.valueOf("HS256"))
assertMatchRule(
configuredJwtDef = JwtDef(
JwtDef.Name("test".nonempty),
AuthorizationTokenDef(Header.Name.authorization, "Bearer "),
SignatureCheckMethod.Hmac(key.getEncoded),
userClaim = Some(ClaimName(JsonPath.compile("userId"))),
groupsClaim = Some(ClaimName(JsonPath.compile("https://{domain}/claims/roles")))
),
tokenHeader = Header(
Header.Name.authorization,
{
val jwtBuilder = Jwts.builder
.signWith(key)
.setSubject("test")
.claim("userId", "user1")
.claim("https://{domain}/claims/roles", List("group1", "group2").asJava)
NonEmptyString.unsafeFrom(s"Bearer ${jwtBuilder.compact}")
}
)
) {
blockContext => assertBlockContext(
loggedUser = Some(LoggedUser(User.Id("user1")))
)(blockContext)
}
}
"groups claim name is defined and no groups field is passed in JWT token claim" in {
val key: Key = Keys.secretKeyFor(SignatureAlgorithm.valueOf("HS256"))
assertMatchRule(
configuredJwtDef = JwtDef(
JwtDef.Name("test".nonempty),
AuthorizationTokenDef(Header.Name.authorization, "Bearer "),
SignatureCheckMethod.Hmac(key.getEncoded),
userClaim = Some(ClaimName("userId".nonempty)),
groupsClaim = Some(ClaimName("groups".nonempty))
userClaim = Some(ClaimName(JsonPath.compile("userId"))),
groupsClaim = Some(ClaimName(JsonPath.compile("groups")))
),
configuredGroups = Set.empty,
tokenHeader = Header(
@@ -180,8 +208,8 @@ class JwtAuthRuleTests
JwtDef.Name("test".nonempty),
AuthorizationTokenDef(Header.Name.authorization, "Bearer "),
SignatureCheckMethod.Hmac(key.getEncoded),
userClaim = Some(ClaimName("userId".nonempty)),
groupsClaim = Some(ClaimName("tech.beshu.groups".nonempty))
userClaim = Some(ClaimName(JsonPath.compile("userId"))),
groupsClaim = Some(ClaimName(JsonPath.compile("tech.beshu.groups")))
),
tokenHeader = Header(
Header.Name.authorization,
@@ -207,8 +235,8 @@ class JwtAuthRuleTests
JwtDef.Name("test".nonempty),
AuthorizationTokenDef(Header.Name.authorization, "Bearer "),
SignatureCheckMethod.Hmac(key.getEncoded),
userClaim = Some(ClaimName("userId".nonempty)),
groupsClaim = Some(ClaimName("groups".nonempty))
userClaim = Some(ClaimName(JsonPath.compile("userId"))),
groupsClaim = Some(ClaimName(JsonPath.compile("groups")))
),
configuredGroups = Set(groupFrom("group3"), groupFrom("group2")),
tokenHeader = Header(
@@ -325,7 +353,7 @@ class JwtAuthRuleTests
JwtDef.Name("test".nonempty),
AuthorizationTokenDef(Header.Name.authorization, "Bearer "),
SignatureCheckMethod.Hmac(key.getEncoded),
userClaim = Some(ClaimName("userId".nonempty)),
userClaim = Some(ClaimName(JsonPath.compile("userId"))),
groupsClaim = None
),
tokenHeader = Header(
@@ -341,8 +369,8 @@ class JwtAuthRuleTests
JwtDef.Name("test".nonempty),
AuthorizationTokenDef(Header.Name.authorization, "Bearer "),
SignatureCheckMethod.Hmac(key.getEncoded),
userClaim = Some(ClaimName("userId".nonempty)),
groupsClaim = Some(ClaimName("groups".nonempty))
userClaim = Some(ClaimName(JsonPath.compile("userId"))),
groupsClaim = Some(ClaimName(JsonPath.compile("groups")))
),
tokenHeader = Header(
Header.Name.authorization,
@@ -357,8 +385,8 @@ class JwtAuthRuleTests
JwtDef.Name("test".nonempty),
AuthorizationTokenDef(Header.Name.authorization, "Bearer "),
SignatureCheckMethod.Hmac(key.getEncoded),
userClaim = Some(ClaimName("userId".nonempty)),
groupsClaim = Some(ClaimName("tech.beshu.groups.subgroups".nonempty))
userClaim = Some(ClaimName(JsonPath.compile("userId"))),
groupsClaim = Some(ClaimName(JsonPath.compile("tech.beshu.groups.subgroups")))
),
configuredGroups = Set(Group("group1".nonempty)),
tokenHeader = Header(
@@ -381,8 +409,8 @@ class JwtAuthRuleTests
JwtDef.Name("test".nonempty),
AuthorizationTokenDef(Header.Name.authorization, "Bearer "),
SignatureCheckMethod.Hmac(key.getEncoded),
userClaim = Some(ClaimName("userId".nonempty)),
groupsClaim = Some(ClaimName("groups".nonempty))
userClaim = Some(ClaimName(JsonPath.compile("userId"))),
groupsClaim = Some(ClaimName(JsonPath.compile("groups")))
),
configuredGroups = Set(groupFrom("group3"), groupFrom("group4")),
tokenHeader = Header(
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.