Skip to content
This repository was archived by the owner on Sep 26, 2020. It is now read-only.

Commit 32b942b

Browse files
authored
Support targeting the Coral (#147)
* Add test to load a model exported from tf 1.15 * Add PostTrainingQuantizationTask * Add RunEdgeTpuCompilerTask * Targeting MobileNetV2 for the Coral works * Bump azure because of docker push * Rename ModelDeploymentTarget.Normal * Bump azure because of docker push * Move failing test from Coral to Desktop * Add disabled test that reproduces the conv2d issue targeting the Coral * Fix running training scripts * Add a form to select the job target * JobRunnerIntegTest working with new result uploading * Cleanup * Documentation and final cleanup * Work on fixing windows build * Work on fixing linux_docker build * Working on fixing windows builds * Don't work with paths manually
1 parent cb795c5 commit 32b942b

File tree

54 files changed

+1455
-385
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

54 files changed

+1455
-385
lines changed

aws/src/main/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunner.kt

+9-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package edu.wpi.axon.aws
33
import edu.wpi.axon.db.data.TrainingScriptProgress
44
import edu.wpi.axon.tfdata.Dataset
55
import edu.wpi.axon.util.FilePath
6+
import java.io.File
67
import mu.KotlinLogging
78
import org.apache.commons.lang3.RandomStringUtils
89
import org.koin.core.KoinComponent
@@ -35,9 +36,6 @@ class EC2TrainingScriptRunner(
3536
require(config.oldModelName is FilePath.S3) {
3637
"Must start from a model in S3. Got: ${config.oldModelName}"
3738
}
38-
require(config.newModelName is FilePath.S3) {
39-
"Must export to a model in S3. Got: ${config.newModelName}"
40-
}
4139
require(config.epochs > 0) {
4240
"Must train for at least one epoch. Got ${config.epochs} epochs."
4341
}
@@ -80,14 +78,14 @@ class EC2TrainingScriptRunner(
8078
|apt-cache policy docker-ce
8179
|apt install -y docker-ce
8280
|systemctl status docker
83-
|pip3 install https://github.com/wpilibsuite/axon-cli/releases/download/v0.1.12/axon-0.1.12-py2.py3-none-any.whl
81+
|pip3 install https://github.com/wpilibsuite/axon-cli/releases/download/v0.1.14/axon-0.1.14-py2.py3-none-any.whl
8482
|axon create-heartbeat ${config.id}
8583
|axon update-training-progress ${config.id} "initializing"
8684
|axon download-untrained-model "${config.oldModelName.path}"
8785
|$downloadDatasetString
8886
|axon download-training-script "$scriptFileName"
89-
|docker run -v ${'$'}(eval "pwd"):/home wpilib/axon-ci:latest "/usr/bin/python3.6 /home/$scriptFileName"
90-
|axon upload-trained-model "${config.newModelName.filename}"
87+
|docker run -v ${'$'}(eval "pwd"):/home wpilib/axon-ci:latest "/usr/bin/python3.6" "/home/$scriptFileName"
88+
|axon upload-training-results ${config.id} "${config.workingDir}"
9189
|axon update-training-progress ${config.id} "completed"
9290
|axon remove-heartbeat ${config.id}
9391
|shutdown -h now
@@ -108,6 +106,11 @@ class EC2TrainingScriptRunner(
108106
canceller.addJob(config.id, instanceId)
109107
}
110108

109+
override fun listResults(id: Int): List<String> = s3Manager.listTrainingResults(id)
110+
111+
override fun getResult(id: Int, filename: String): File =
112+
s3Manager.downloadTrainingResult(id, filename)
113+
111114
fun getInstanceId(jobId: Int): String {
112115
requireJobIsInMaps(jobId)
113116
return instanceIds[jobId]!!

aws/src/main/kotlin/edu/wpi/axon/aws/LocalTrainingScriptRunner.kt

+17-8
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ import edu.wpi.axon.db.data.TrainingScriptProgress
44
import edu.wpi.axon.tfdata.Dataset
55
import edu.wpi.axon.util.FilePath
66
import edu.wpi.axon.util.createLocalProgressFilepath
7+
import edu.wpi.axon.util.getOutputModelName
78
import edu.wpi.axon.util.runCommand
89
import java.io.File
9-
import java.nio.file.Files
10-
import java.nio.file.Paths
1110
import kotlin.concurrent.thread
1211
import mu.KotlinLogging
12+
import org.apache.commons.lang3.RandomStringUtils
1313

1414
/**
1515
* Runs the training script on the local machine. Assumes that Axon is running in the
@@ -31,9 +31,6 @@ class LocalTrainingScriptRunner(
3131
require(config.oldModelName is FilePath.Local) {
3232
"Must start from a local model. Got: ${config.oldModelName}"
3333
}
34-
require(config.newModelName is FilePath.Local) {
35-
"Must export to a local model. Got: ${config.newModelName}"
36-
}
3734
require(config.epochs > 0) {
3835
"Must train for at least one epoch. Got ${config.epochs} epochs."
3936
}
@@ -43,7 +40,8 @@ class LocalTrainingScriptRunner(
4340
}
4441
}
4542

46-
val scriptFile = Files.createTempFile("", ".py").toFile().apply {
43+
val scriptFilename = "${config.workingDir}/${RandomStringUtils.randomAlphanumeric(20)}.py"
44+
val scriptFile = File(scriptFilename).apply {
4745
createNewFile()
4846
writeText(config.scriptContents)
4947
}
@@ -83,8 +81,9 @@ class LocalTrainingScriptRunner(
8381
""".trimMargin()
8482
}
8583

86-
val newModelFile =
87-
Paths.get(config.newModelName.path).toFile()
84+
val newModelFile = config.workingDir
85+
.resolve(getOutputModelName(config.oldModelName.filename))
86+
.toFile()
8887
if (newModelFile.exists()) {
8988
scriptProgressMap[config.id] = TrainingScriptProgress.Completed
9089
} else {
@@ -100,6 +99,16 @@ class LocalTrainingScriptRunner(
10099
}
101100
}
102101

102+
override fun listResults(id: Int): List<String> {
103+
requireJobIsInMaps(id)
104+
return scriptDataMap[id]!!.workingDir.toFile().listFiles()!!.map { it.name }
105+
}
106+
107+
override fun getResult(id: Int, filename: String): File {
108+
requireJobIsInMaps(id)
109+
return scriptDataMap[id]!!.workingDir.resolve(filename).toFile()
110+
}
111+
103112
override fun getTrainingProgress(jobId: Int) = progressReporter.getTrainingProgress(jobId)
104113

105114
override fun overrideTrainingProgress(jobId: Int, progress: TrainingScriptProgress) =

aws/src/main/kotlin/edu/wpi/axon/aws/RunTrainingScriptConfiguration.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,22 @@ package edu.wpi.axon.aws
22

33
import edu.wpi.axon.tfdata.Dataset
44
import edu.wpi.axon.util.FilePath
5+
import java.nio.file.Path
56

67
/**
78
* The configuration data needed to run a training script.
89
*
910
* @param oldModelName The path to the current model (that will be loaded).
10-
* @param newModelName The path to the new model (that will be saved).
1111
* @param dataset The path to the dataset.
1212
* @param scriptContents The contents of the training script.
1313
* @param epochs The number of epochs the model will be trained for. Must be greater than zero.
1414
* @param id The id of the Job this script is associated with.
1515
*/
1616
data class RunTrainingScriptConfiguration(
1717
val oldModelName: FilePath,
18-
val newModelName: FilePath,
1918
val dataset: Dataset,
2019
val scriptContents: String,
2120
val epochs: Int,
21+
val workingDir: Path,
2222
val id: Int
2323
)

aws/src/main/kotlin/edu/wpi/axon/aws/S3Manager.kt

+11-10
Original file line numberDiff line numberDiff line change
@@ -38,22 +38,23 @@ class S3Manager(
3838
downloadToLocalFile("axon-untrained-models/$filename")
3939

4040
/**
41-
* Uploads a trained model (one that the user wants to test with).
41+
* Lists the training results for the Job.
4242
*
43-
* @param file The local file containing the model to upload. The filename of the uploaded model
44-
* will be the same as the filename of this file.
43+
* @param jobId The ID of the Job.
44+
* @return The filenames of the results.
4545
*/
46-
fun uploadTrainedModel(file: File) = uploadLocalFile(file, "axon-trained-models/${file.name}")
46+
fun listTrainingResults(jobId: Int): List<String> =
47+
listObjectsWithPrefixAndRemovePrefix("axon-training-results/$jobId/")
4748

4849
/**
49-
* Downloads a trained model (which was put in S3 by the training script when it
50-
* ran on EC2). Meant to be used to download to the user's local machine.
50+
* Downloads a training result to a local file.
5151
*
52-
* @param filename The filename of the trained model file.
53-
* @return A local file containing the trained model.
52+
* @param jobId The ID of the Job.
53+
* @param resultFilename The filename of the result to download.
54+
* @return A local file containing the result.
5455
*/
55-
fun downloadTrainedModel(filename: String): File =
56-
downloadToLocalFile("axon-trained-models/$filename")
56+
fun downloadTrainingResult(jobId: Int, resultFilename: String): File =
57+
downloadToLocalFile("axon-training-results/$jobId/$resultFilename")
5758

5859
/**
5960
* Uploads a test data file.

aws/src/main/kotlin/edu/wpi/axon/aws/TrainingScriptRunner.kt

+19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package edu.wpi.axon.aws
22

3+
import java.io.File
4+
35
interface TrainingScriptRunner : TrainingScriptProgressReporter, TrainingScriptCanceller {
46

57
/**
@@ -8,4 +10,21 @@ interface TrainingScriptRunner : TrainingScriptProgressReporter, TrainingScriptC
810
* @param config The data needed to start the script.
911
*/
1012
fun startScript(config: RunTrainingScriptConfiguration)
13+
14+
/**
15+
* Lists the results from running the training script.
16+
*
17+
* @param id The Job ID.
18+
* @return The names of the results.
19+
*/
20+
fun listResults(id: Int): List<String>
21+
22+
/**
23+
* Gets a result as a local file.
24+
*
25+
* @param id The Job ID.
26+
* @param filename The name of the result to download.
27+
* @return A local file containing the result.
28+
*/
29+
fun getResult(id: Int, filename: String): File
1130
}

aws/src/main/kotlin/edu/wpi/axon/aws/preferences/Preferences.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import software.amazon.awssdk.services.ec2.model.InstanceType
1313
*/
1414
@Serializable
1515
data class Preferences(
16-
var defaultEC2NodeType: InstanceType = InstanceType.T2_MICRO,
16+
var defaultEC2NodeType: InstanceType = InstanceType.T2_SMALL,
1717
var statusPollingDelay: Long = 5000
1818
) {
1919

aws/src/test/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunnerTest.kt

+12-26
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ import io.kotlintest.shouldThrow
99
import io.mockk.every
1010
import io.mockk.mockk
1111
import io.mockk.verify
12+
import java.io.File
1213
import org.apache.commons.lang3.RandomStringUtils
1314
import org.junit.jupiter.api.Test
15+
import org.junit.jupiter.api.io.TempDir
1416
import software.amazon.awssdk.services.ec2.model.InstanceStateName
1517
import software.amazon.awssdk.services.ec2.model.InstanceType
1618

@@ -23,8 +25,8 @@ internal class EC2TrainingScriptRunnerTest {
2325
)
2426

2527
@Test
26-
fun `test starting script and getting the state until completion`() {
27-
val config = randomRunTrainingScriptConfigurationUsingAWS()
28+
fun `test starting script and getting the state until completion`(@TempDir tempDir: File) {
29+
val config = randomRunTrainingScriptConfigurationUsingAWS(tempDir)
2830
val instanceId = RandomStringUtils.randomAlphanumeric(10)
2931
val instanceType = InstanceType.T2_MICRO
3032

@@ -89,8 +91,8 @@ internal class EC2TrainingScriptRunnerTest {
8991
}
9092

9193
@Test
92-
fun `test starting the script and cancelling it`() {
93-
val config = randomRunTrainingScriptConfigurationUsingAWS()
94+
fun `test starting the script and cancelling it`(@TempDir tempDir: File) {
95+
val config = randomRunTrainingScriptConfigurationUsingAWS(tempDir)
9496
val instanceId = RandomStringUtils.randomAlphanumeric(10)
9597
val instanceType = InstanceType.T2_MICRO
9698

@@ -150,63 +152,47 @@ internal class EC2TrainingScriptRunnerTest {
150152
}
151153

152154
@Test
153-
fun `test running with local old model`() {
155+
fun `test running with local old model`(@TempDir tempDir: File) {
154156
shouldThrow<IllegalArgumentException> {
155157
runner.startScript(
156158
RunTrainingScriptConfiguration(
157159
FilePath.Local("a"),
158-
FilePath.S3("b"),
159160
Dataset.ExampleDataset.FashionMnist,
160161
"",
161162
1,
163+
tempDir.toPath(),
162164
1
163165
)
164166
)
165167
}
166168
}
167169

168170
@Test
169-
fun `test running with local new model`() {
171+
fun `test running with zero epochs`(@TempDir tempDir: File) {
170172
shouldThrow<IllegalArgumentException> {
171173
runner.startScript(
172174
RunTrainingScriptConfiguration(
173175
FilePath.S3("a"),
174-
FilePath.Local("b"),
175-
Dataset.ExampleDataset.FashionMnist,
176-
"",
177-
1,
178-
1
179-
)
180-
)
181-
}
182-
}
183-
184-
@Test
185-
fun `test running with zero epochs`() {
186-
shouldThrow<IllegalArgumentException> {
187-
runner.startScript(
188-
RunTrainingScriptConfiguration(
189-
FilePath.S3("a"),
190-
FilePath.S3("b"),
191176
Dataset.ExampleDataset.FashionMnist,
192177
"",
193178
0,
179+
tempDir.toPath(),
194180
1
195181
)
196182
)
197183
}
198184
}
199185

200186
@Test
201-
fun `test running with local dataset`() {
187+
fun `test running with local dataset`(@TempDir tempDir: File) {
202188
shouldThrow<IllegalArgumentException> {
203189
runner.startScript(
204190
RunTrainingScriptConfiguration(
205191
FilePath.S3("a"),
206-
FilePath.S3("b"),
207192
Dataset.Custom(FilePath.Local("d"), "d"),
208193
"",
209194
1,
195+
tempDir.toPath(),
210196
1
211197
)
212198
)

aws/src/test/kotlin/edu/wpi/axon/aws/LocalTrainingScriptProgressReporterTest.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ internal class LocalTrainingScriptProgressReporterTest {
2727
reporter.addJobAfterRestart(
2828
RunTrainingScriptConfiguration(
2929
FilePath.Local("old.h5"),
30-
FilePath.Local("new.h5"),
3130
Dataset.ExampleDataset.Mnist,
3231
"",
3332
10,
33+
tempDir.toPath(),
3434
id
3535
)
3636
)
@@ -53,10 +53,10 @@ internal class LocalTrainingScriptProgressReporterTest {
5353
reporter.addJobAfterRestart(
5454
RunTrainingScriptConfiguration(
5555
FilePath.Local("old.h5"),
56-
FilePath.Local("new.h5"),
5756
Dataset.ExampleDataset.Mnist,
5857
"",
5958
10,
59+
tempDir.toPath(),
6060
id
6161
)
6262
)
@@ -79,10 +79,10 @@ internal class LocalTrainingScriptProgressReporterTest {
7979
reporter.addJobAfterRestart(
8080
RunTrainingScriptConfiguration(
8181
FilePath.Local("old.h5"),
82-
FilePath.Local("new.h5"),
8382
Dataset.ExampleDataset.Mnist,
8483
"",
8584
10,
85+
tempDir.toPath(),
8686
id
8787
)
8888
)

0 commit comments

Comments
 (0)