Skip to content

Commit

Permalink
Kafka: authenticate with Event Hubs using OAuth2 (close #57)
Browse files Browse the repository at this point in the history
  • Loading branch information
spenes committed Feb 15, 2024
1 parent 6cdcbac commit 4c4b5d7
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 10 deletions.
6 changes: 6 additions & 0 deletions modules/kafka/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ snowplow.defaults {
"group.id": null # invalid value MUST be overridden by the applicaion
"allow.auto.create.topics": "false"
"auto.offset.reset": "latest"
"security.protocol": "SASL_SSL"
"sasl.mechanism": "OAUTHBEARER"
"sasl.jaas.config": "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
}
}
}
Expand All @@ -13,6 +16,9 @@ snowplow.defaults {
kafka: {
producerConf: {
"client.id": null # invalid value MUST be overriden by the application
"security.protocol": "SASL_SSL"
"sasl.mechanism": "OAUTHBEARER"
"sasl.jaas.config": "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
}
maxRecordSize: 1000000
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved.
*
* This program is licensed to you under the Snowplow Community License Version 1.0,
* and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
* You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
*/
package com.snowplowanalytics.snowplow.azure

import java.net.URI
import java.{lang, util}

import com.nimbusds.jwt.JWTParser

import javax.security.auth.callback.Callback
import javax.security.auth.callback.UnsupportedCallbackException
import javax.security.auth.login.AppConfigurationEntry

import org.apache.kafka.clients.CommonClientConfigs
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback

import com.azure.identity.DefaultAzureCredentialBuilder
import com.azure.core.credential.TokenRequestContext

trait AzureAuthenticationCallbackHandler extends AuthenticateCallbackHandler {

val credentials = new DefaultAzureCredentialBuilder().build()

var sbUri: String = ""

override def configure(
configs: util.Map[String, _],
saslMechanism: String,
jaasConfigEntries: util.List[AppConfigurationEntry]
): Unit = {
val bootstrapServer =
configs
.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG)
.toString
.replaceAll("\\[|\\]", "")
.split(",")
.toList
.headOption match {
case Some(s) => s
case None => throw new Exception("Empty bootstrap servers list")
}
val uri = URI.create("https://" + bootstrapServer)
// Workload identity works with '.default' scope
this.sbUri = s"${uri.getScheme}://${uri.getHost}/.default"
}

override def handle(callbacks: Array[Callback]): Unit =
callbacks.foreach {
case callback: OAuthBearerTokenCallback =>
val token = getOAuthBearerToken()
callback.token(token)
case callback => throw new UnsupportedCallbackException(callback)
}

def getOAuthBearerToken(): OAuthBearerToken = {
val reqContext = new TokenRequestContext()
reqContext.addScopes(sbUri)
val accessToken = credentials.getTokenSync(reqContext).getToken
val jwt = JWTParser.parse(accessToken)
val claims = jwt.getJWTClaimsSet

new OAuthBearerToken {
override def value(): String = accessToken

override def lifetimeMs(): Long = claims.getExpirationTime.getTime

override def scope(): util.Set[String] = null

override def principalName(): String = null

override def startTimeMs(): lang.Long = null
}
}

override def close(): Unit = ()
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,21 @@ import cats.Monad
import com.snowplowanalytics.snowplow.sinks.{Sink, Sinkable}
import fs2.kafka._

import scala.reflect._

import java.util.UUID

import com.snowplowanalytics.snowplow.azure.AzureAuthenticationCallbackHandler

object KafkaSink {

def resource[F[_]: Async](config: KafkaSinkConfig): Resource[F, Sink[F]] = {
def resource[F[_]: Async, T <: AzureAuthenticationCallbackHandler](
config: KafkaSinkConfig,
authHandlerClass: ClassTag[T]
): Resource[F, Sink[F]] = {
val producerSettings =
ProducerSettings[F, String, Array[Byte]]
.withProperty("sasl.login.callback.handler.class", authHandlerClass.runtimeClass.getName)
.withBootstrapServers(config.bootstrapServers)
.withProperties(config.producerConf)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import fs2.Stream
import org.typelevel.log4cats.{Logger, SelfAwareStructuredLogger}
import org.typelevel.log4cats.slf4j.Slf4jLogger

import scala.reflect._

import java.nio.ByteBuffer
import java.time.Instant

Expand All @@ -26,20 +28,27 @@ import org.apache.kafka.common.TopicPartition
// snowplow
import com.snowplowanalytics.snowplow.sources.SourceAndAck
import com.snowplowanalytics.snowplow.sources.internal.{Checkpointer, LowLevelEvents, LowLevelSource}
import com.snowplowanalytics.snowplow.azure.AzureAuthenticationCallbackHandler

object KafkaSource {

private implicit def logger[F[_]: Sync]: SelfAwareStructuredLogger[F] = Slf4jLogger.getLogger[F]

def build[F[_]: Async](config: KafkaSourceConfig): F[SourceAndAck[F]] =
LowLevelSource.toSourceAndAck(lowLevel(config))
def build[F[_]: Async, T <: AzureAuthenticationCallbackHandler](
config: KafkaSourceConfig,
authHandlerClass: ClassTag[T]
): F[SourceAndAck[F]] =
LowLevelSource.toSourceAndAck(lowLevel(config, authHandlerClass))

private def lowLevel[F[_]: Async](config: KafkaSourceConfig): LowLevelSource[F, KafkaCheckpoints[F]] =
private def lowLevel[F[_]: Async, T <: AzureAuthenticationCallbackHandler](
config: KafkaSourceConfig,
authHandlerClass: ClassTag[T]
): LowLevelSource[F, KafkaCheckpoints[F]] =
new LowLevelSource[F, KafkaCheckpoints[F]] {
def checkpointer: Checkpointer[F, KafkaCheckpoints[F]] = kafkaCheckpointer

def stream: Stream[F, Stream[F, LowLevelEvents[KafkaCheckpoints[F]]]] =
kafkaStream(config)
kafkaStream(config, authHandlerClass)
}

case class OffsetAndCommit[F[_]](offset: Long, commit: F[Unit])
Expand All @@ -59,9 +68,12 @@ object KafkaSource {
def nack(c: KafkaCheckpoints[F]): F[Unit] = Applicative[F].unit
}

private def kafkaStream[F[_]: Async](config: KafkaSourceConfig): Stream[F, Stream[F, LowLevelEvents[KafkaCheckpoints[F]]]] =
private def kafkaStream[F[_]: Async, T <: AzureAuthenticationCallbackHandler](
config: KafkaSourceConfig,
authHandlerClass: ClassTag[T]
): Stream[F, Stream[F, LowLevelEvents[KafkaCheckpoints[F]]]] =
KafkaConsumer
.stream(consumerSettings[F](config))
.stream(consumerSettings[F, T](config, authHandlerClass))
.evalTap(_.subscribeTo(config.topicName))
.flatMap { consumer =>
consumer.partitionsMapStream
Expand Down Expand Up @@ -124,8 +136,12 @@ object KafkaSource {
private implicit def byteBufferDeserializer[F[_]: Sync]: Resource[F, ValueDeserializer[F, ByteBuffer]] =
Resource.pure(Deserializer.lift(arr => Sync[F].pure(ByteBuffer.wrap(arr))))

private def consumerSettings[F[_]: Async](config: KafkaSourceConfig): ConsumerSettings[F, Array[Byte], ByteBuffer] =
private def consumerSettings[F[_]: Async, T <: AzureAuthenticationCallbackHandler](
config: KafkaSourceConfig,
authHandlerClass: ClassTag[T]
): ConsumerSettings[F, Array[Byte], ByteBuffer] =
ConsumerSettings[F, Array[Byte], ByteBuffer]
.withProperty("sasl.login.callback.handler.class", authHandlerClass.runtimeClass.getName)
.withBootstrapServers(config.bootstrapServers)
.withProperties(config.consumerConf)
.withEnableAutoCommit(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ class KafkaSinkConfigSpec extends Specification {
topicName = "my-topic",
bootstrapServers = "my-bootstrap-server:9092",
producerConf = Map(
"client.id" -> "my-client-id"
"client.id" -> "my-client-id",
"security.protocol" -> "SASL_SSL",
"sasl.mechanism" -> "OAUTHBEARER",
"sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ class KafkaSourceConfigSpec extends Specification {
consumerConf = Map(
"group.id" -> "my-consumer-group",
"allow.auto.create.topics" -> "false",
"auto.offset.reset" -> "latest"
"auto.offset.reset" -> "latest",
"security.protocol" -> "SASL_SSL",
"sasl.mechanism" -> "OAUTHBEARER",
"sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
)
)

Expand Down
1 change: 1 addition & 0 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ object Dependencies {
fs2Kafka,
circeConfig,
circeGeneric,
azureIdentity,
snappy,
specs2
)
Expand Down

0 comments on commit 4c4b5d7

Please sign in to comment.