@@ -12,8 +12,7 @@ import edu.wpi.axon.dsl.running
12
12
import edu.wpi.axon.dsl.task.ApplySequentialLayerDeltaTask
13
13
import edu.wpi.axon.dsl.variable.Variable
14
14
import edu.wpi.axon.tfdata.Model
15
- import edu.wpi.axon.tflayerloader.DefaultLayersToGraph
16
- import edu.wpi.axon.tflayerloader.LoadLayersFromHDF5
15
+ import edu.wpi.axon.tflayerloader.ModelLoaderFactory
17
16
import java.io.File
18
17
19
18
/* *
@@ -33,24 +32,25 @@ class TrainSequentialModelScriptGenerator(
33
32
}
34
33
}
35
34
36
- private val loadLayersFromHDF5 = LoadLayersFromHDF5 ( DefaultLayersToGraph () )
35
+ private val modelLoaderFactory = ModelLoaderFactory ( )
37
36
38
- override fun generateScript (): Validated <NonEmptyList <String >, String> =
39
- loadLayersFromHDF5.load(File (trainState.userOldModelPath)).flatMap { oldModel ->
37
+ override fun generateScript (): Validated <NonEmptyList <String >, String> {
38
+ val modelLoader = modelLoaderFactory.createModeLoader(trainState.userOldModelPath)
39
+ return modelLoader.load(File (trainState.userOldModelPath)).flatMap { oldModel ->
40
40
IO {
41
41
require(oldModel is Model .Sequential )
42
42
require(trainState.userNewModel.batchInputShape.count { it == null } <= 1 )
43
43
val reshapeArgsFromBatchShape =
44
- trainState.userNewModel.batchInputShape.map { it ? : - 1 }
44
+ trainState.userNewModel.batchInputShape.map { it ? : - 1 }
45
45
46
46
val script = ScriptGenerator (
47
- DefaultPolymorphicNamedDomainObjectContainer .of(),
48
- DefaultPolymorphicNamedDomainObjectContainer .of()
47
+ DefaultPolymorphicNamedDomainObjectContainer .of(),
48
+ DefaultPolymorphicNamedDomainObjectContainer .of()
49
49
) {
50
50
val loadedDataset = reshapeAndScaleLoadedDataset(
51
- loadDataset(trainState),
52
- reshapeArgsFromBatchShape,
53
- 255
51
+ loadDataset(trainState),
52
+ reshapeArgsFromBatchShape,
53
+ 255
54
54
)
55
55
56
56
val model = loadModel(trainState)
@@ -64,18 +64,19 @@ class TrainSequentialModelScriptGenerator(
64
64
}
65
65
66
66
lastTask = compileTrainSave(
67
- trainState,
68
- oldModel,
69
- newModel,
70
- applyLayerDeltaTask,
71
- loadedDataset
67
+ trainState,
68
+ oldModel,
69
+ newModel,
70
+ applyLayerDeltaTask,
71
+ loadedDataset
72
72
)
73
73
}
74
74
75
75
script.code(trainState.generateDebugComments)
76
76
}
77
77
}.attempt().unsafeRunSync().fold(
78
- { Throwables .getStackTraceAsString(it).invalidNel() },
79
- { it }
78
+ { Throwables .getStackTraceAsString(it).invalidNel() },
79
+ { it }
80
80
)
81
+ }
81
82
}
0 commit comments