3 changes: 2 additions & 1 deletion build.sbt
@@ -12,7 +12,8 @@ libraryDependencies += "org.apache.spark" %% "spark-core" % sparkVersion % Provi
libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion % Provided
libraryDependencies += "org.apache.spark" %% "spark-hive" % sparkVersion % Provided
libraryDependencies += "com.databricks" % "dbutils-api_2.12" % "0.0.5" % Provided
libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.11.595" % Provided
libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.11.595"
libraryDependencies += "com.amazonaws" % "aws-java-sdk-secretsmanager" % "1.11.595"
libraryDependencies += "io.delta" % "delta-core_2.12" % "1.0.0" % Provided
libraryDependencies += "org.scalaj" %% "scalaj-http" % "2.4.2"
//libraryDependencies += "org.apache.hive" % "hive-metastore" % "2.3.9"
@@ -88,14 +88,25 @@ class ParamDeserializer() extends StdDeserializer[OverwatchParams](classOf[Overw
override def deserialize(jp: JsonParser, ctxt: DeserializationContext): OverwatchParams = {
val masterNode = jp.getCodec.readTree[JsonNode](jp)

val token = try {
Some(TokenSecret(
masterNode.get("tokenSecret").get("scope").asText(),
masterNode.get("tokenSecret").get("key").asText()))
} catch {
case e: Throwable =>
println("No Token Secret Defined", e)
None
// TODO: consider introducing an enum that captures each secret type's inner structure and
// turning the block below into a function that processes that enum in a loop
val token = {

val databricksToken =
for {
scope <- getOptionString(masterNode, "tokenSecret.scope")
key <- getOptionString(masterNode, "tokenSecret.key")
} yield TokenSecret(scope, key)

val finalToken = if (databricksToken.isEmpty)
for {
secretId <- getOptionString(masterNode, "tokenSecret.secretId")
region <- getOptionString(masterNode, "tokenSecret.region")
apiToken <- getOptionString(masterNode, "tokenSecret.tokenKey")
} yield AwsTokenSecret(secretId, region, apiToken)
else databricksToken

finalToken
}

val rawAuditPath = getOptionString(masterNode, "auditLogConfig.rawAuditPath")
@@ -265,7 +265,7 @@ class Initializer(config: Config) extends SparkSessionWrapper {
config.setExternalizeOptimize(rawParams.externalizeOptimize)

val overwatchScope = rawParams.overwatchScope.getOrElse(Seq("all"))
val tokenSecret = rawParams.tokenSecret

// TODO -- PRIORITY -- If data target is null -- default table gets dbfs:/null
val dataTarget = rawParams.dataTarget.getOrElse(
DataTarget(Some("overwatch"), Some("dbfs:/user/hive/warehouse/overwatch.db"), None))
@@ -275,24 +275,30 @@
if (overwatchScope.head == "all") config.setOverwatchScope(config.orderedOverwatchScope)
else config.setOverwatchScope(validateScope(overwatchScope))

// validate token secret requirements
// TODO - Validate if token has access to necessary assets. Warn/Fail if not
if (tokenSecret.nonEmpty && !disableValidations && !config.isLocalTesting) {
if (tokenSecret.get.scope.isEmpty || tokenSecret.get.key.isEmpty) {
throw new BadConfigException(s"Secret AND Key must be provided together or neither of them. " +
s"Either supply both or neither.")
if (rawParams.tokenSecret.nonEmpty && !disableValidations && !config.isLocalTesting) {
rawParams.tokenSecret.map {
case databricksSecret: TokenSecret =>
// validate token secret requirements
// TODO - Validate if databricks token has access to necessary assets. Warn/Fail if not

if (databricksSecret.scope.isEmpty || databricksSecret.key.isEmpty) {
throw new BadConfigException(s"Secret AND Key must be provided together or neither of them. " +
s"Either supply both or neither.")
}
val scopeCheck = dbutils.secrets.listScopes().map(_.getName()).toArray.filter(_ == databricksSecret.scope)
if (scopeCheck.length == 0) throw new BadConfigException(s"Scope ${databricksSecret.scope} does not exist " +
s"in this workspace. Please provide a scope available and accessible to this account.")
val scopeName = scopeCheck.head

val keyCheck = dbutils.secrets.list(scopeName).toArray.filter(_.key == databricksSecret.key)
if (keyCheck.length == 0) throw new BadConfigException(s"Key ${databricksSecret.key} does not exist " +
s"within the provided scope: ${databricksSecret.scope}. Please provide a scope and key " +
s"available and accessible to this account.")

config.registerWorkspaceMeta(Some(TokenSecret(scopeName, keyCheck.head.key)))

case awsSecret: AwsTokenSecret => config.registerWorkspaceMeta(Some(awsSecret))
}
val scopeCheck = dbutils.secrets.listScopes().map(_.getName()).toArray.filter(_ == tokenSecret.get.scope)
if (scopeCheck.length == 0) throw new BadConfigException(s"Scope ${tokenSecret.get.scope} does not exist " +
s"in this workspace. Please provide a scope available and accessible to this account.")
val scopeName = scopeCheck.head

val keyCheck = dbutils.secrets.list(scopeName).toArray.filter(_.key == tokenSecret.get.key)
if (keyCheck.length == 0) throw new BadConfigException(s"Key ${tokenSecret.get.key} does not exist " +
s"within the provided scope: ${tokenSecret.get.scope}. Please provide a scope and key " +
s"available and accessible to this account.")

config.registerWorkspaceMeta(Some(TokenSecret(scopeName, keyCheck.head.key)))
} else config.registerWorkspaceMeta(None)

// Validate data Target
@@ -4,6 +4,7 @@ import com.databricks.labs.overwatch.pipeline.TransformFunctions._
import com.databricks.labs.overwatch.utils._
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

class Module(
val moduleId: Int,
@@ -172,10 +173,27 @@ class Module(
initState
}

private def normalizeToken(secretToken: TokenSecret, reportDf: DataFrame): DataFrame = {
val inputConfigCols = reportDf.select($"inputConfig.*")
.columns
.filter(_ != "tokenSecret")
.map(name => col("inputConfig." + name))

reportDf
.withColumn(
"inputConfig",
struct(inputConfigCols :+ struct(lit(secretToken.scope), lit(secretToken.key)).as("tokenSecret"): _*)
)
}

private def finalizeModule(report: ModuleStatusReport): Unit = {
pipeline.updateModuleState(report.simple)
if (!pipeline.readOnly) {
pipeline.database.write(Seq(report).toDF, pipeline.pipelineStateTarget, pipeline.pipelineSnapTime.asColumnTS)
// Guard against a missing tokenSecret so the lookup cannot throw when no secret was configured
val targetDf = report.inputConfig.tokenSecret match {
case Some(secret) => normalizeToken(SecretTools(secret).getTargetTableStruct, Seq(report).toDF)
case None => Seq(report).toDF
}
pipeline.database.write(
targetDf,
pipeline.pipelineStateTarget, pipeline.pipelineSnapTime.asColumnTS)
}
}

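For reference, a minimal, self-contained sketch of the idea behind normalizeToken above: rebuild the nested inputConfig struct so the tokenSecret field always has the same two-field shape regardless of secret type. Column and field names here are illustrative, not taken from the Overwatch schema.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object NormalizeTokenSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("normalize-token-sketch").getOrCreate()
  import spark.implicits._

  // Toy frame with a nested inputConfig struct (illustrative columns only)
  val df = Seq(("dbfs:/audit", "0.6.0")).toDF("rawAuditPath", "version")
    .withColumn("inputConfig", struct($"rawAuditPath", $"version"))

  // Keep every existing inputConfig field except tokenSecret, then append the
  // normalized struct(scope, key) such as the one produced by getTargetTableStruct
  val keep = df.select($"inputConfig.*").columns
    .filter(_ != "tokenSecret")
    .map(n => col("inputConfig." + n))

  val normalized = df.withColumn(
    "inputConfig",
    struct(keep :+ struct(lit("my-scope").as("scope"), lit("my-key").as("key")).as("tokenSecret"): _*)
  )
  normalized.printSchema()
}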
@@ -0,0 +1,43 @@
package com.databricks.labs.overwatch.utils

import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder
import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest
import org.apache.log4j.{Level, Logger}
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse

import java.util.Base64

object AwsSecrets {
private val logger: Logger = Logger.getLogger(this.getClass)

def readApiToken(secretId: String, region: String, apiTokenKey: String = "apiToken"): String = {
secretValueAsMap(secretId, region)
.getOrElse(apiTokenKey, throw new IllegalStateException(s"Key '$apiTokenKey' not found in secret $secretId"))
.asInstanceOf[String]
}

def secretValueAsMap(secretId: String, region: String = "us-east-2"): Map[String, Any] =
parseJsonToMap(readRawSecretFromAws(secretId, region))

def readRawSecretFromAws(secretId: String, region: String): String = {
logger.log(Level.INFO, s"Looking up secret $secretId in AWS Secrets Manager")

val secretsClient = AWSSecretsManagerClientBuilder
.standard()
.withRegion(region)
.build()
val request = new GetSecretValueRequest().withSecretId(secretId)
val secretValue = secretsClient.getSecretValue(request)

if (secretValue.getSecretString != null)
secretValue.getSecretString
else
new String(Base64.getDecoder.decode(secretValue.getSecretBinary).array)
}

def parseJsonToMap(jsonStr: String): Map[String, Any] = {
implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats
parse(jsonStr).extract[Map[String, Any]]
}
}
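As a usage sketch (the secret name below is an assumption, not a value from this change), the new AwsSecrets helper can be called like this:

// Hypothetical secret "overwatch/dev" in us-east-2 containing {"apiToken": "dapi..."}
val token: String = AwsSecrets.readApiToken(
  secretId = "overwatch/dev",   // assumed secret name
  region = "us-east-2",
  apiTokenKey = "apiToken"      // default key; shown here for clarity
)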
12 changes: 3 additions & 9 deletions src/main/scala/com/databricks/labs/overwatch/utils/Config.scala
@@ -312,20 +312,14 @@ class Config() {
* as the job owner or notebook user (if called from notebook)
* @return
*/
private[overwatch] def registerWorkspaceMeta(tokenSecret: Option[TokenSecret]): this.type = {
private[overwatch] def registerWorkspaceMeta(tokenSecret: Option[TokenSecretContainer]): this.type = {
var rawToken = ""
var scope = ""
var key = ""
try {
// Token secrets not supported in local testing
if (tokenSecret.nonEmpty && !_isLocalTesting) { // not local testing and secret passed
_workspaceUrl = dbutils.notebook.getContext().apiUrl.get
_cloudProvider = if (_workspaceUrl.toLowerCase().contains("azure")) "azure" else "aws"
scope = tokenSecret.get.scope
key = tokenSecret.get.key
rawToken = dbutils.secrets.get(scope, key)
val authMessage = s"Valid Secret Identified: Executing with token located in secret, $scope : $key"
logger.log(Level.INFO, authMessage)
rawToken = SecretTools(tokenSecret.get).getApiToken
_tokenType = "Secret"
} else {
if (_isLocalTesting) { // Local testing env vars
@@ -344,7 +338,7 @@
}
}
if (!rawToken.matches("^(dapi|dkea)[a-zA-Z0-9-]*$")) throw new BadConfigException(s"contents of secret " +
s"at scope:key $scope:$key is not in a valid format. Please validate the contents of your secret. It must be " +
s"is not in a valid format. Please validate the contents of your secret. It must be " +
s"a user access token. It should start with 'dapi' ")
setApiEnv(ApiEnv(isLocalTesting, workspaceURL, rawToken, packageVersion))
this
@@ -0,0 +1,57 @@
package com.databricks.labs.overwatch.utils

import com.databricks.dbutils_v1.DBUtilsHolder.dbutils
import org.apache.log4j.{Level, Logger}

/**
* SecretTools handles common functionality for working with secrets:
* 1) Fetch the Databricks API token stored in the specified secret
* 2) Normalize the secret structure before it is persisted to the Delta table pipeline_report under the nested struct column inputConfig.tokenSecret
* Two secret types are currently supported: AWS Secrets Manager and Databricks secrets
*/
trait SecretTools[T <: TokenSecretContainer] {
def getApiToken : String
def getTargetTableStruct: TokenSecret
}

object SecretTools {
private val logger: Logger = Logger.getLogger(this.getClass)
type DatabricksTokenSecret = TokenSecret

private class DatabricksSecretTools(tokenSecret : DatabricksTokenSecret) extends SecretTools[DatabricksTokenSecret] {
override def getApiToken: String = {
val scope = tokenSecret.scope
val key = tokenSecret.key
val authMessage = s"Executing with Databricks token located in secret, scope=$scope : key=$key"
logger.log(Level.INFO, authMessage)
dbutils.secrets.get(scope, key)
}

override def getTargetTableStruct: TokenSecret = {
TokenSecret(tokenSecret.scope,tokenSecret.key)
}
}

private class AwsSecretTools(tokenSecret : AwsTokenSecret) extends SecretTools[AwsTokenSecret] {
override def getApiToken: String = {
val secretId = tokenSecret.secretId
val region = tokenSecret.region
val tokenKey = tokenSecret.tokenKey
val authMessage = s"Executing with AWS token located in secret, secretId=$secretId : region=$region : tokenKey=$tokenKey"
logger.log(Level.INFO, authMessage)
AwsSecrets.readApiToken(secretId, region, tokenKey)
}

override def getTargetTableStruct: TokenSecret = {
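// Reuse the two-field pipeline_report struct for AWS secrets: the scope slot holds the region, the key slot holds the secretId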
TokenSecret(tokenSecret.region, tokenSecret.secretId)
}
}

def apply(secretSource: TokenSecretContainer): SecretTools[_] = {
secretSource match {
case x: AwsTokenSecret => new AwsSecretTools(x)
case y: DatabricksTokenSecret => new DatabricksSecretTools(y)
case _ => throw new IllegalArgumentException(s"${secretSource.toString} not implemented")
}
}
}
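A brief usage sketch of the SecretTools factory (scope, key, and secretId values are placeholders): both secret types resolve to an API token and to the normalized struct persisted in pipeline_report.

// Databricks-backed secret
val dbToken = SecretTools(TokenSecret("overwatch-scope", "api-key")).getApiToken

// AWS Secrets Manager-backed secret (tokenKey defaults to "apiToken")
val awsToken = SecretTools(AwsTokenSecret("overwatch", "us-east-2")).getApiToken

// Normalized struct written to pipeline_report.inputConfig.tokenSecret
val tableStruct = SecretTools(AwsTokenSecret("overwatch", "us-east-2")).getTargetTableStruct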
@@ -21,7 +21,9 @@ case class SparkDetail()

case class GangliaDetail()

case class TokenSecret(scope: String, key: String)
abstract class TokenSecretContainer extends Product with Serializable
case class TokenSecret(scope: String, key: String) extends TokenSecretContainer
case class AwsTokenSecret(secretId: String, region: String, tokenKey: String = "apiToken") extends TokenSecretContainer

case class DataTarget(databaseName: Option[String], databaseLocation: Option[String], etlDataPathPrefix: Option[String],
consumerDatabaseName: Option[String] = None, consumerDatabaseLocation: Option[String] = None)
@@ -75,7 +77,7 @@ case class AuditLogConfig(
case class IntelligentScaling(enabled: Boolean = false, minimumCores: Int = 4, maximumCores: Int = 512, coeff: Double = 1.0)

case class OverwatchParams(auditLogConfig: AuditLogConfig,
tokenSecret: Option[TokenSecret] = None,
tokenSecret: Option[TokenSecretContainer] = None,
dataTarget: Option[DataTarget] = None,
badRecordsPath: Option[String] = None,
overwatchScope: Option[Seq[String]] = None,
@@ -349,9 +351,15 @@ object OverwatchEncoders {
implicit def overwatchScope: org.apache.spark.sql.Encoder[OverwatchScope] =
org.apache.spark.sql.Encoders.kryo[OverwatchScope]

/*
implicit def tokenSecret: org.apache.spark.sql.Encoder[TokenSecret] =
org.apache.spark.sql.Encoders.kryo[TokenSecret]

implicit def tokenSecretContainer: org.apache.spark.sql.Encoder[TokenSecretContainer] =
org.apache.spark.sql.Encoders.kryo[TokenSecretContainer]

*/

implicit def dataTarget: org.apache.spark.sql.Encoder[DataTarget] =
org.apache.spark.sql.Encoders.kryo[DataTarget]

@@ -11,7 +11,7 @@ import com.fasterxml.jackson.module.scala.DefaultScalaModule
import io.delta.tables.DeltaTable
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.hadoop.conf._
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.functions._
import org.apache.spark.util.SerializableConfiguration
@@ -12,6 +12,34 @@ class ParamDeserializerTest extends AnyFunSpec {

describe("ParamDeserializer") {

val paramModule: SimpleModule = new SimpleModule()
.addDeserializer(classOf[OverwatchParams], new ParamDeserializer)
val mapper: ObjectMapper with ScalaObjectMapper = (new ObjectMapper() with ScalaObjectMapper)
.registerModule(DefaultScalaModule)
.registerModule(paramModule)
.asInstanceOf[ObjectMapper with ScalaObjectMapper]

it("should decode passed token string as AWS secrets") {
val AWSsecrets = """
|{"tokenSecret":{"secretId":"overwatch","region":"us-east-2","tokenKey":"apiToken"}}
|""".stripMargin


val expected = Some(AwsTokenSecret("overwatch", "us-east-2", "apiToken"))
val parsed = mapper.readValue[OverwatchParams](AWSsecrets)
assertResult(expected)(parsed.tokenSecret)
}

it("should decode passed token string as Databricks secrets") {
val Databrickssecrets = """
|{"tokenSecret":{"scope":"overwatch", "key":"test-key"}}
|""".stripMargin

val expected = Some(TokenSecret("overwatch", "test-key"))
val parsed = mapper.readValue[OverwatchParams](Databrickssecrets)
assertResult(expected)(parsed.tokenSecret)
}

it("should decode incomplete parameters") {
val incomplete = """
|{"auditLogConfig":{"azureAuditLogEventhubConfig":{"connectionString":"test","eventHubName":"overwatch-evhub",
@@ -24,13 +52,6 @@
|"workspace_name":"myTestWorkspace", "externalizeOptimizations":"false"}
|""".stripMargin

val paramModule: SimpleModule = new SimpleModule()
.addDeserializer(classOf[OverwatchParams], new ParamDeserializer)
val mapper: ObjectMapper with ScalaObjectMapper = (new ObjectMapper() with ScalaObjectMapper)
.registerModule(DefaultScalaModule)
.registerModule(paramModule)
.asInstanceOf[ObjectMapper with ScalaObjectMapper]

val expected = OverwatchParams(
AuditLogConfig(
azureAuditLogEventhubConfig = Some(AzureAuditLogEventhubConfig(