Fix writing of model in the local file system (#528)
sanmitra committed Nov 17, 2020
1 parent bc507f2 commit 13ad9cd
Showing 2 changed files with 42 additions and 21 deletions.
55 changes: 35 additions & 20 deletions core/src/main/scala/com/salesforce/op/OpWorkflowModelWriter.scala
@@ -31,15 +31,14 @@
package com.salesforce.op

import java.io.File
import java.nio.charset.StandardCharsets

import com.salesforce.op.features.FeatureJsonHelper
import com.salesforce.op.filters.RawFeatureFilterResults
import com.salesforce.op.stages.{OPStage, OpPipelineStageWriter}
import com.salesforce.op.utils.spark.{JobGroupUtil, OpStep}
import enumeratum._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.spark.ml.util.MLWriter
import org.json4s.JsonAST.{JArray, JObject, JString}
import org.json4s.JsonDSL._
@@ -59,11 +58,39 @@ class OpWorkflowModelWriter(val model: OpWorkflowModel) extends MLWriter {

  implicit val jsonFormats: Formats = DefaultFormats

  protected var modelStagingDir: String = WorkflowFileReader.modelStagingDir

  /**
   * Set the local folder to copy and unpack stored model to for loading
   */
  def setModelStagingDir(localDir: String): this.type = {
    modelStagingDir = localDir
    this
  }

  override protected def saveImpl(path: String): Unit = {
    JobGroupUtil.withJobGroup(OpStep.ModelIO) {
      sc.parallelize(Seq(toJsonString(path)), 1)
        .saveAsTextFile(OpWorkflowModelReadWriteShared.jsonPath(path), classOf[GzipCodec])
    }(this.sparkSession)
    val conf = new Configuration()
    val localFileSystem = FileSystem.getLocal(conf)
    val localPath = localFileSystem.makeQualified(new Path(modelStagingDir))
    localFileSystem.delete(localPath, true)
    val raw = new Path(localPath, WorkflowFileReader.rawModel)

    val rawPathStr = raw.toString
    val modelJson = toJsonString(rawPathStr)
    val jsonPath = OpWorkflowModelReadWriteShared.jsonPath(rawPathStr)
    val os = localFileSystem.create(new Path(jsonPath))
    try {
      os.write(modelJson.getBytes(StandardCharsets.UTF_8.toString))
    } finally {
      os.close()
    }

    val compressed = new Path(localPath, WorkflowFileReader.zipModel)
    ZipUtil.pack(new File(raw.toUri.getPath), new File(compressed.toUri.getPath))

    val finalPath = new Path(path, WorkflowFileReader.zipModel)
    val destinationFileSystem = finalPath.getFileSystem(conf)
    destinationFileSystem.moveFromLocalFile(compressed, finalPath)
  }

/**
@@ -207,21 +234,9 @@ object OpWorkflowModelWriter {
    overwrite: Boolean = true,
    modelStagingDir: String = WorkflowFileReader.modelStagingDir
  ): Unit = {
    val localPath = new Path(modelStagingDir)
    val conf = new Configuration()
    val localFileSystem = FileSystem.getLocal(conf)
    if (overwrite) localFileSystem.delete(localPath, true)
    val raw = new Path(modelStagingDir, WorkflowFileReader.rawModel)

    val w = new OpWorkflowModelWriter(model)
    val w = new OpWorkflowModelWriter(model).setModelStagingDir(modelStagingDir)
    val writer = if (overwrite) w.overwrite() else w
    writer.save(raw.toString)
    val compressed = new Path(modelStagingDir, WorkflowFileReader.zipModel)
    ZipUtil.pack(new File(raw.toString), new File(compressed.toString))

    val finalPath = new Path(path, WorkflowFileReader.zipModel)
    val destinationFileSystem = finalPath.getFileSystem(conf)
    destinationFileSystem.moveFromLocalFile(compressed, finalPath)
    writer.save(path)
  }

/**
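For context, a minimal usage sketch of the updated write path (not part of the commit). It assumes a fitted OpWorkflowModel is already in hand; the destination and staging paths below are hypothetical. With this change the model JSON is first written under modelStagingDir on the local file system, packed into a zip, and only then moved to the destination path, which may sit on a different file system such as HDFS or S3.

import com.salesforce.op.{OpWorkflowModel, OpWorkflowModelWriter}

// Sketch only: stage the model locally, zip it, then move it to the destination file system.
def saveFittedModel(model: OpWorkflowModel, destination: String): Unit = {
  OpWorkflowModelWriter.save(
    model = model,
    path = destination,                        // e.g. "hdfs://..." or "s3a://..." (hypothetical)
    overwrite = true,
    modelStagingDir = "/tmp/op-model-staging"  // hypothetical local staging directory
  )
}
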
8 changes: 7 additions &amp; 1 deletion core/src/test/scala/com/salesforce/op/OpWorkflowRunnerTest.scala
@@ -232,7 +232,13 @@ class OpWorkflowRunnerTest extends AsyncFlatSpec with PassengerSparkFixtureTest
      dirFile.isDirectory shouldBe true
      // TODO: maybe do a thorough files inspection here
      val files = FileUtils.listFiles(dirFile, null, true)
      files.asScala.map(_.toString).exists(_.contains("_SUCCESS")) shouldBe true
      val fileNames = files.asScala.map(_.getName)
      if (outFile.getAbsolutePath.endsWith("/model")) {
        fileNames should contain ("op-model.json")
      }
      else {
        fileNames should contain ("_SUCCESS")
      }
      files.size > 1
    }
    res shouldBe a[R]
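The updated assertion reflects the new write path: the model output directory is expected to contain the model JSON (op-model.json) rather than Spark's _SUCCESS marker, while the other result directories still carry _SUCCESS. A minimal sketch of the same check outside the test harness, reusing the commons-io call from the test; the output directory argument is hypothetical.

import java.io.File
import org.apache.commons.io.FileUtils
import scala.collection.JavaConverters._

// Sketch only: mirrors the test's expectation for a given output directory.
def hasExpectedMarker(outDir: File): Boolean = {
  val fileNames = FileUtils.listFiles(outDir, null, true).asScala.map(_.getName).toSet
  if (outDir.getAbsolutePath.endsWith("/model")) fileNames.contains("op-model.json")
  else fileNames.contains("_SUCCESS")
}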
