Skip to content
This repository was archived by the owner on Sep 26, 2020. It is now read-only.

Commit 9e4daf4

Browse files
authored
Create a model loading interface (#133)
* Add a ModelLoader interface instead of using LoadLayersFromHDF5 directly * Rename LoadLayersFromHDF5
1 parent 1d1b95b commit 9e4daf4

File tree

11 files changed

+106
-49
lines changed

11 files changed

+106
-49
lines changed

example-models/src/main/kotlin/edu/wpi/axon/examplemodel/ExampleModelManager.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ import arrow.fx.IO
88
import arrow.fx.extensions.fx
99
import edu.wpi.axon.tfdata.Model
1010
import edu.wpi.axon.tfdata.layer.Layer
11-
import edu.wpi.axon.tflayerloader.DefaultLayersToGraph
12-
import edu.wpi.axon.tflayerloader.LoadLayersFromHDF5
11+
import edu.wpi.axon.tflayerloader.ModelLoaderFactory
1312
import java.io.File
1413
import org.octogonapus.ktguava.collections.mapNodes
1514

@@ -54,12 +53,13 @@ interface ExampleModelManager {
5453
* @param exampleModelManager The manager to download with.
5554
* @return The configured model.
5655
*/
56+
@Suppress("UnstableApiUsage")
5757
fun downloadAndConfigureExampleModel(
5858
exampleModel: ExampleModel,
5959
exampleModelManager: ExampleModelManager
6060
): IO<Tuple2<Model, File>> = IO.fx {
6161
val file = exampleModelManager.download(exampleModel).bind()
62-
val model = LoadLayersFromHDF5(DefaultLayersToGraph()).load(File(file.absolutePath)).bind()
62+
val model = ModelLoaderFactory().createModeLoader(file.name).load(File(file.absolutePath)).bind()
6363

6464
val freezeLayerTransform: (Layer.MetaLayer) -> Layer.MetaLayer = { layer ->
6565
exampleModel.freezeLayers[layer.name]?.let {

tf-layer-loader/src/main/kotlin/edu/wpi/axon/tflayerloader/LoadLayersFromHDF5.kt renamed to tf-layer-loader/src/main/kotlin/edu/wpi/axon/tflayerloader/HDF5ModelLoader.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,17 @@ import java.io.File
2727
/**
2828
* Loads TensorFlow layers from an HDF5 file.
2929
*/
30-
class LoadLayersFromHDF5(
30+
internal class HDF5ModelLoader(
3131
private val layersToGraph: LayersToGraph
32-
) {
32+
) : ModelLoader {
3333

3434
/**
3535
* Load layers from the [file].
3636
*
3737
* @param file The file to load from.
3838
* @return The layers in the file.
3939
*/
40-
fun load(file: File): IO<Model> = IO {
40+
override fun load(file: File): IO<Model> = IO {
4141
HdfFile(file).use {
4242
val config = it.getAttribute("model_config").data as String
4343
val data = Parser.default().parse(config.byteInputStream()) as JsonObject
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package edu.wpi.axon.tflayerloader
2+
3+
import arrow.fx.IO
4+
import edu.wpi.axon.tfdata.Model
5+
import java.io.File
6+
7+
interface ModelLoader {
8+
9+
/**
10+
* Load a [Model] from the [file].
11+
*
12+
* @param file The file to load from.
13+
* @return A new [Model] containing as much of the information from the [file] as possible.
14+
*/
15+
fun load(file: File): IO<Model>
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package edu.wpi.axon.tflayerloader
2+
3+
class ModelLoaderFactory {
4+
5+
/**
6+
* Creates a new [ModelLoader] based on the extension of the [modelFilename].
7+
*
8+
* @param modelFilename The filename of the model file that is going to be loaded.
9+
* @return A new [ModelLoader].
10+
*/
11+
fun createModeLoader(modelFilename: String): ModelLoader = when {
12+
modelFilename.endsWith(".h5") || modelFilename.endsWith(".hdf5") ->
13+
HDF5ModelLoader(DefaultLayersToGraph())
14+
else -> error("Model file type not supported for model with filename: $modelFilename")
15+
}
16+
}

tf-layer-loader/src/test/kotlin/edu/wpi/axon/tflayerloader/LoadLayersFromHDF5IntegrationTest.kt renamed to tf-layer-loader/src/test/kotlin/edu/wpi/axon/tflayerloader/HDF5ModelLoaderIntegrationTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import io.kotlintest.matchers.collections.shouldHaveSize
1717
import io.kotlintest.shouldBe
1818
import org.junit.jupiter.api.Test
1919

20-
internal class LoadLayersFromHDF5IntegrationTest {
20+
internal class HDF5ModelLoaderIntegrationTest {
2121

2222
@Test
2323
fun `load from test file 1`() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package edu.wpi.axon.tflayerloader
2+
3+
import kotlin.test.assertTrue
4+
import org.junit.jupiter.api.Test
5+
import org.junit.jupiter.api.assertThrows
6+
7+
internal class ModelLoaderFactoryTest {
8+
9+
private val factory = ModelLoaderFactory()
10+
11+
@Test
12+
fun `test h5 file`() {
13+
assertTrue(factory.createModeLoader("a.h5") is HDF5ModelLoader)
14+
}
15+
16+
@Test
17+
fun `test hdf5 file`() {
18+
assertTrue(factory.createModeLoader("a.hdf5") is HDF5ModelLoader)
19+
}
20+
21+
@Test
22+
fun `test unknown file`() {
23+
assertThrows<IllegalStateException> {
24+
factory.createModeLoader("a.abcd")
25+
}
26+
}
27+
}

tf-layer-loader/src/test/kotlin/edu/wpi/axon/tflayerloader/ModelTestUtil.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import java.io.File
1313
* @param block Will be run with the loaded model.
1414
*/
1515
internal inline fun <reified T : Model> loadModel(filename: String, noinline block: (T) -> Unit) {
16-
LoadLayersFromHDF5(DefaultLayersToGraph()).load(
16+
HDF5ModelLoader(DefaultLayersToGraph()).load(
1717
File(block::class.java.getResource(filename).toURI())
1818
).unsafeRunSync().apply { shouldBeInstanceOf(block) }
1919
}
@@ -25,7 +25,7 @@ internal inline fun <reified T : Model> loadModel(filename: String, noinline blo
2525
* @param stub Used to get the class to get a resource from. Do not use this parameter.
2626
*/
2727
fun loadModelFails(filename: String, stub: () -> Unit = {}) {
28-
LoadLayersFromHDF5(DefaultLayersToGraph()).load(
28+
HDF5ModelLoader(DefaultLayersToGraph()).load(
2929
File(stub::class.java.getResource(filename).toURI())
3030
).attempt().unsafeRunSync().shouldBeLeft()
3131
}

training-test-util/src/main/kotlin/edu/wpi/axon/training/testutil/TrainTestUtil.kt

+2-4
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ package edu.wpi.axon.training.testutil
33
import arrow.core.Tuple3
44
import arrow.fx.IO
55
import edu.wpi.axon.tfdata.Model
6-
import edu.wpi.axon.tflayerloader.DefaultLayersToGraph
7-
import edu.wpi.axon.tflayerloader.LoadLayersFromHDF5
6+
import edu.wpi.axon.tflayerloader.ModelLoaderFactory
87
import io.kotlintest.assertions.arrow.either.shouldBeRight
98
import io.kotlintest.matchers.file.shouldExist
109
import io.kotlintest.matchers.string.shouldNotBeEmpty
@@ -26,8 +25,7 @@ private val LOGGER = KotlinLogging.logger("training-test-util")
2625
*/
2726
fun loadModel(modelName: String, stub: () -> Unit): Pair<Model, String> {
2827
val localModelPath = Paths.get(stub::class.java.getResource(modelName).toURI()).toString()
29-
val layers = LoadLayersFromHDF5(DefaultLayersToGraph())
30-
.load(File(localModelPath))
28+
val layers = ModelLoaderFactory().createModeLoader(localModelPath).load(File(localModelPath))
3129
val model = layers.attempt().unsafeRunSync()
3230
model.shouldBeRight()
3331
return model.b as Model to localModelPath

training/src/main/kotlin/edu/wpi/axon/training/TrainGeneralModelScriptGenerator.kt

+15-14
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ import edu.wpi.axon.dsl.running
1212
import edu.wpi.axon.dsl.task.ApplyFunctionalLayerDeltaTask
1313
import edu.wpi.axon.dsl.variable.Variable
1414
import edu.wpi.axon.tfdata.Model
15-
import edu.wpi.axon.tflayerloader.DefaultLayersToGraph
16-
import edu.wpi.axon.tflayerloader.LoadLayersFromHDF5
15+
import edu.wpi.axon.tflayerloader.ModelLoaderFactory
1716
import java.io.File
1817

1918
/**
@@ -32,17 +31,18 @@ class TrainGeneralModelScriptGenerator(
3231
}
3332
}
3433

35-
private val loadLayersFromHDF5 = LoadLayersFromHDF5(DefaultLayersToGraph())
34+
private val modelLoaderFactory = ModelLoaderFactory()
3635

3736
@Suppress("UNUSED_VARIABLE")
38-
override fun generateScript(): Validated<NonEmptyList<String>, String> =
39-
loadLayersFromHDF5.load(File(trainState.userOldModelPath)).flatMap { userOldModel ->
37+
override fun generateScript(): Validated<NonEmptyList<String>, String> {
38+
val modelLoader = modelLoaderFactory.createModeLoader(trainState.userOldModelPath)
39+
return modelLoader.load(File(trainState.userOldModelPath)).flatMap { userOldModel ->
4040
IO {
4141
require(userOldModel is Model.General)
4242

4343
val script = ScriptGenerator(
44-
DefaultPolymorphicNamedDomainObjectContainer.of(),
45-
DefaultPolymorphicNamedDomainObjectContainer.of()
44+
DefaultPolymorphicNamedDomainObjectContainer.of(),
45+
DefaultPolymorphicNamedDomainObjectContainer.of()
4646
) {
4747
val loadedDataset = loadDataset(trainState).let { dataset ->
4848
if (trainState.userNewModel.input.size == 1) {
@@ -69,18 +69,19 @@ class TrainGeneralModelScriptGenerator(
6969
}
7070

7171
lastTask = compileTrainSave(
72-
trainState,
73-
userOldModel,
74-
newModelVar,
75-
applyLayerDeltaTask,
76-
loadedDataset
72+
trainState,
73+
userOldModel,
74+
newModelVar,
75+
applyLayerDeltaTask,
76+
loadedDataset
7777
)
7878
}
7979

8080
script.code(trainState.generateDebugComments)
8181
}
8282
}.attempt().unsafeRunSync().fold(
83-
{ Throwables.getStackTraceAsString(it).invalidNel() },
84-
{ it }
83+
{ Throwables.getStackTraceAsString(it).invalidNel() },
84+
{ it }
8585
)
86+
}
8687
}

training/src/main/kotlin/edu/wpi/axon/training/TrainSequentialModelScriptGenerator.kt

+19-18
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ import edu.wpi.axon.dsl.running
1212
import edu.wpi.axon.dsl.task.ApplySequentialLayerDeltaTask
1313
import edu.wpi.axon.dsl.variable.Variable
1414
import edu.wpi.axon.tfdata.Model
15-
import edu.wpi.axon.tflayerloader.DefaultLayersToGraph
16-
import edu.wpi.axon.tflayerloader.LoadLayersFromHDF5
15+
import edu.wpi.axon.tflayerloader.ModelLoaderFactory
1716
import java.io.File
1817

1918
/**
@@ -33,24 +32,25 @@ class TrainSequentialModelScriptGenerator(
3332
}
3433
}
3534

36-
private val loadLayersFromHDF5 = LoadLayersFromHDF5(DefaultLayersToGraph())
35+
private val modelLoaderFactory = ModelLoaderFactory()
3736

38-
override fun generateScript(): Validated<NonEmptyList<String>, String> =
39-
loadLayersFromHDF5.load(File(trainState.userOldModelPath)).flatMap { oldModel ->
37+
override fun generateScript(): Validated<NonEmptyList<String>, String> {
38+
val modelLoader = modelLoaderFactory.createModeLoader(trainState.userOldModelPath)
39+
return modelLoader.load(File(trainState.userOldModelPath)).flatMap { oldModel ->
4040
IO {
4141
require(oldModel is Model.Sequential)
4242
require(trainState.userNewModel.batchInputShape.count { it == null } <= 1)
4343
val reshapeArgsFromBatchShape =
44-
trainState.userNewModel.batchInputShape.map { it ?: -1 }
44+
trainState.userNewModel.batchInputShape.map { it ?: -1 }
4545

4646
val script = ScriptGenerator(
47-
DefaultPolymorphicNamedDomainObjectContainer.of(),
48-
DefaultPolymorphicNamedDomainObjectContainer.of()
47+
DefaultPolymorphicNamedDomainObjectContainer.of(),
48+
DefaultPolymorphicNamedDomainObjectContainer.of()
4949
) {
5050
val loadedDataset = reshapeAndScaleLoadedDataset(
51-
loadDataset(trainState),
52-
reshapeArgsFromBatchShape,
53-
255
51+
loadDataset(trainState),
52+
reshapeArgsFromBatchShape,
53+
255
5454
)
5555

5656
val model = loadModel(trainState)
@@ -64,18 +64,19 @@ class TrainSequentialModelScriptGenerator(
6464
}
6565

6666
lastTask = compileTrainSave(
67-
trainState,
68-
oldModel,
69-
newModel,
70-
applyLayerDeltaTask,
71-
loadedDataset
67+
trainState,
68+
oldModel,
69+
newModel,
70+
applyLayerDeltaTask,
71+
loadedDataset
7272
)
7373
}
7474

7575
script.code(trainState.generateDebugComments)
7676
}
7777
}.attempt().unsafeRunSync().fold(
78-
{ Throwables.getStackTraceAsString(it).invalidNel() },
79-
{ it }
78+
{ Throwables.getStackTraceAsString(it).invalidNel() },
79+
{ it }
8080
)
81+
}
8182
}

ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/service/JobService.kt

+2-4
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ import edu.wpi.axon.tfdata.Dataset
1010
import edu.wpi.axon.tfdata.Model
1111
import edu.wpi.axon.tfdata.loss.Loss
1212
import edu.wpi.axon.tfdata.optimizer.Optimizer
13-
import edu.wpi.axon.tflayerloader.DefaultLayersToGraph
14-
import edu.wpi.axon.tflayerloader.LoadLayersFromHDF5
13+
import edu.wpi.axon.tflayerloader.ModelLoaderFactory
1514
import edu.wpi.axon.ui.JobRunner
1615
import java.io.File
1716
import java.nio.file.Paths
@@ -52,8 +51,7 @@ object JobService {
5251

5352
private fun loadModel(modelName: String): Pair<Model, String> {
5453
val localModelPath = Paths.get("/home/salmon/Documents/Axon/training/src/test/resources/edu/wpi/axon/training/$modelName").toString()
55-
val layers = LoadLayersFromHDF5(DefaultLayersToGraph())
56-
.load(File(localModelPath))
54+
val layers = ModelLoaderFactory().createModeLoader(localModelPath).load(File(localModelPath))
5755
val model = layers.attempt().unsafeRunSync()
5856
check(model is Either.Right)
5957
return model.b to localModelPath

0 commit comments

Comments
 (0)