Skip to content
This repository was archived by the owner on Sep 26, 2020. It is now read-only.

Commit e155bb3

Browse files
authored
Add LocalTrainingScriptRunner (#138)
* Remove anything related to interfacing with S3 in the training script
* Initial implementation of LocalTrainingScriptRunner
* Fix SaveModelTaskTest
* Add LocalProgressReportingCallbackTask
* Fix integration tests
* Add missing dependency
* Fix issues with injecting the bucket. Fix bug in LocalPreferencesManager. Add LocalPreferencesManagerTest.
* Add a job for running locally. Improve logging. Working on a test failure.
* Patch the scripts
* Fix Conv32321IntegrationTest
* Improve progress reporting for LocalTrainingScriptRunner
* Convert Dataset.Custom to use a FilePath instead of assuming it is in S3
* Update util/src/main/kotlin/edu/wpi/axon/util/Util.kt
* Address PR comments
* Simplify local progress file creation logic
* Fix running jobs with EC2. Improve status reporting. Eagerly load AWS bucket. Download untrained models from S3 if necessary.
* Add comment
1 parent cdcface commit e155bb3

File tree

59 files changed

+1431
-694
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+1431
-694
lines changed

aws/aws.gradle.kts

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies {
2121

2222
api(project(":tf-data"))
2323
api(project(":db-data"))
24+
api(project(":training"))
2425

2526
api(
2627
group = "software.amazon.awssdk",

aws/src/main/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunner.kt

+55-25
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package edu.wpi.axon.aws
22

33
import edu.wpi.axon.dbdata.TrainingScriptProgress
44
import edu.wpi.axon.tfdata.Dataset
5+
import edu.wpi.axon.util.FilePath
56
import java.lang.NumberFormatException
67
import java.util.Base64
78
import java.util.concurrent.atomic.AtomicLong
@@ -10,6 +11,7 @@ import org.apache.commons.lang3.RandomStringUtils
1011
import org.koin.core.KoinComponent
1112
import software.amazon.awssdk.services.ec2.Ec2Client
1213
import software.amazon.awssdk.services.ec2.model.Ec2Exception
14+
import software.amazon.awssdk.services.ec2.model.Filter
1315
import software.amazon.awssdk.services.ec2.model.InstanceStateName
1416
import software.amazon.awssdk.services.ec2.model.InstanceType
1517
import software.amazon.awssdk.services.ec2.model.ShutdownBehavior
@@ -20,11 +22,12 @@ import software.amazon.awssdk.services.ec2.model.ShutdownBehavior
2022
* this class will handle all of that. The script should just load and save the model from/to its
2123
* current directory.
2224
*
25+
* @param bucketName The name of the S3 bucket to use.
2326
* @param instanceType The type of the EC2 instance to run the training script on.
2427
*/
2528
class EC2TrainingScriptRunner(
2629
bucketName: String,
27-
private val instanceType: InstanceType // TODO: Move this to [startScript]
30+
private val instanceType: InstanceType
2831
) : TrainingScriptRunner, KoinComponent {
2932

3033
private val ec2 by lazy { Ec2Client.builder().build() }
@@ -35,30 +38,32 @@ class EC2TrainingScriptRunner(
3538
private val scriptDataMap = mutableMapOf<Long, RunTrainingScriptConfiguration>()
3639

3740
override fun startScript(
38-
runTrainingScriptConfiguration: RunTrainingScriptConfiguration
41+
config: RunTrainingScriptConfiguration
3942
): Long {
40-
// Check for if the script uses the CLI to manage the model in S3. This class is supposed to
41-
// own working with S3.
42-
require(
43-
!runTrainingScriptConfiguration.scriptContents.contains("download_model") &&
44-
!runTrainingScriptConfiguration.scriptContents.contains("upload_model")
45-
) {
46-
"""
47-
|Cannot start the script because it interfaces with AWS:
48-
|${runTrainingScriptConfiguration.scriptContents}
49-
|
50-
""".trimMargin()
43+
require(config.oldModelName is FilePath.S3) {
44+
"Must start from a model in S3. Got: ${config.oldModelName}"
45+
}
46+
require(config.newModelName is FilePath.S3) {
47+
"Must export to a model in S3. Got: ${config.newModelName}"
48+
}
49+
require(config.epochs > 0) {
50+
"Must train for at least one epoch. Got ${config.epochs} epochs."
51+
}
52+
when (config.dataset) {
53+
is Dataset.Custom -> require(config.dataset.path is FilePath.S3) {
54+
"Custom datasets must be in S3. Got non-local dataset: ${config.dataset}"
55+
}
5156
}
5257

5358
// The file name for the generated script
5459
val scriptFileName = "${RandomStringUtils.randomAlphanumeric(20)}.py"
5560

56-
val newModelName = runTrainingScriptConfiguration.newModelName
57-
val datasetName = runTrainingScriptConfiguration.dataset.nameForS3ProgressReporting
61+
val newModelName = config.newModelName.filename
62+
val datasetName = config.dataset.progressReportingName
5863

5964
s3Manager.uploadTrainingScript(
6065
scriptFileName,
61-
runTrainingScriptConfiguration.scriptContents
66+
config.scriptContents
6267
)
6368

6469
// Reset the training progress so the script doesn't start in the completed state
@@ -69,10 +74,10 @@ class EC2TrainingScriptRunner(
6974

7075
// We need to download custom datasets from S3. Example datasets will be downloaded
7176
// by the script using Keras.
72-
val downloadDatasetString = when (runTrainingScriptConfiguration.dataset) {
77+
val downloadDatasetString = when (config.dataset) {
7378
is Dataset.ExampleDataset -> ""
7479
is Dataset.Custom ->
75-
"""axon download-dataset "${runTrainingScriptConfiguration.dataset.pathInS3}""""
80+
"""axon download-dataset "${config.dataset.path.path}""""
7681
}
7782

7883
val scriptForEC2 = """
@@ -91,7 +96,7 @@ class EC2TrainingScriptRunner(
9196
|pip3 install https://github.com/wpilibsuite/axon-cli/releases/download/v0.1.11/axon-0.1.11-py2.py3-none-any.whl
9297
|axon create-heartbeat "$newModelName" "$datasetName"
9398
|axon update-training-progress "$newModelName" "$datasetName" "initializing"
94-
|axon download-untrained-model "${runTrainingScriptConfiguration.oldModelName}"
99+
|axon download-untrained-model "${config.oldModelName.path}"
95100
|$downloadDatasetString
96101
|axon download-training-script "$scriptFileName"
97102
|docker run -v ${'$'}(eval "pwd"):/home wpilib/axon-ci:latest "/usr/bin/python3.6 /home/$scriptFileName"
@@ -122,7 +127,7 @@ class EC2TrainingScriptRunner(
122127

123128
val scriptId = nextScriptId.getAndIncrement()
124129
instanceIds[scriptId] = runInstancesResponse.instances().first().instanceId()
125-
scriptDataMap[scriptId] = runTrainingScriptConfiguration
130+
scriptDataMap[scriptId] = config
126131
return scriptId
127132
}
128133

@@ -132,14 +137,24 @@ class EC2TrainingScriptRunner(
132137
require(scriptId in scriptDataMap.keys)
133138

134139
val runTrainingScriptConfiguration = scriptDataMap[scriptId]!!
135-
val newModelName = runTrainingScriptConfiguration.newModelName
136-
val datasetName = runTrainingScriptConfiguration.dataset.nameForS3ProgressReporting
140+
val newModelName = runTrainingScriptConfiguration.newModelName.filename
141+
val datasetName = runTrainingScriptConfiguration.dataset.progressReportingName
137142

138143
val status = try {
139144
ec2.describeInstanceStatus {
140145
it.instanceIds(instanceIds[scriptId]!!)
146+
.includeAllInstances(true)
147+
.filters(
148+
Filter.builder().name("instance-state-name").values(
149+
"pending",
150+
"running",
151+
"shutting-down",
152+
"stopping"
153+
).build()
154+
)
141155
}.instanceStatuses().firstOrNull()?.instanceState()?.name()
142156
} catch (ex: Ec2Exception) {
157+
LOGGER.warn(ex) { "Failed to get instance status." }
143158
null
144159
}
145160

@@ -166,13 +181,29 @@ class EC2TrainingScriptRunner(
166181
status: InstanceStateName?,
167182
epochs: Int
168183
): TrainingScriptProgress {
184+
LOGGER.debug {
185+
"""
186+
|Heartbeat: $heartbeat
187+
|Progress: $progress
188+
|Instance status: $status
189+
""".trimMargin()
190+
}
191+
169192
val progressAssumingEverythingIsFine = computeProgressAssumingEverythingIsFine(
170193
heartbeat,
171194
progress,
172195
status,
173196
epochs
174197
)
175198

199+
if ((status == InstanceStateName.SHUTTING_DOWN ||
200+
status == InstanceStateName.TERMINATED ||
201+
status == InstanceStateName.STOPPING) &&
202+
(heartbeat != "0" || progress != "completed")
203+
) {
204+
return TrainingScriptProgress.Error
205+
}
206+
176207
return when (heartbeat) {
177208
"0" -> when (progress) {
178209
"not started", "completed" -> progressAssumingEverythingIsFine
@@ -185,8 +216,7 @@ class EC2TrainingScriptRunner(
185216

186217
else -> when (status) {
187218
InstanceStateName.SHUTTING_DOWN, InstanceStateName.TERMINATED,
188-
InstanceStateName.STOPPING, InstanceStateName.STOPPED, null ->
189-
TrainingScriptProgress.Error
219+
InstanceStateName.STOPPING -> TrainingScriptProgress.Error
190220

191221
else -> progressAssumingEverythingIsFine
192222
}
@@ -205,7 +235,7 @@ class EC2TrainingScriptRunner(
205235
"not started" -> when (status) {
206236
InstanceStateName.PENDING -> TrainingScriptProgress.Creating
207237
InstanceStateName.RUNNING -> TrainingScriptProgress.Initializing
208-
else -> TrainingScriptProgress.NotStarted
238+
else -> TrainingScriptProgress.Creating
209239
}
210240

211241
"initializing" -> TrainingScriptProgress.Initializing
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,33 @@
11
package edu.wpi.axon.aws
22

3+
import arrow.core.None
4+
import arrow.core.Option
5+
import arrow.core.Some
36
import mu.KotlinLogging
47
import software.amazon.awssdk.core.exception.SdkClientException
58
import software.amazon.awssdk.services.s3.S3Client
69

710
/**
 * Finds the S3 bucket Axon will work out of. Returns [None] if there is no matching bucket, which
 * causes Axon to run locally and not interface with AWS. The AWS region MUST be auto-detectable
 * from the environment (like when running on ECS). To use AWS when running locally, set
 * `AWS_REGION` to your preferred region. To not use AWS when running locally, do not set
 * `AWS_REGION`.
 *
 * A bucket matches when its name starts with `axon-autogenerated-`; the first match in the listing
 * is used. An [SdkClientException] raised while building the client or listing the buckets is
 * logged and reported as [None].
 *
 * @return The name of the bucket or [None] if the bucket could not be found.
 */
fun findAxonS3Bucket(): Option<String> = try {
    val s3Client = S3Client.builder().build()

    try {
        // Use firstOrNull so that "no matching bucket" yields None instead of throwing
        // NoSuchElementException, which the catch below would not handle.
        val bucket = s3Client.listBuckets().buckets().firstOrNull {
            it.name().startsWith("axon-autogenerated-")
        }

        if (bucket == null) {
            LOGGER.info { "Did not find a matching S3 bucket." }
            None
        } else {
            LOGGER.info { "Starting with S3 bucket: $bucket" }
            Some(bucket.name())
        }
    } finally {
        // The client is only needed for this lookup, so release its resources.
        s3Client.close()
    }
} catch (e: SdkClientException) {
    LOGGER.info(e) { "Not loading credentials because of this exception." }
    None
}
2832

2933
private val LOGGER = KotlinLogging.logger { }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package edu.wpi.axon.aws
2+
3+
import edu.wpi.axon.dbdata.TrainingScriptProgress
4+
import edu.wpi.axon.tfdata.Dataset
5+
import edu.wpi.axon.util.FilePath
6+
import edu.wpi.axon.util.createProgressFilePath
7+
import edu.wpi.axon.util.runCommand
8+
import java.io.File
9+
import java.lang.NumberFormatException
10+
import java.nio.file.Files
11+
import java.nio.file.Paths
12+
import java.util.concurrent.atomic.AtomicLong
13+
import kotlin.concurrent.thread
14+
import mu.KotlinLogging
15+
16+
/**
 * Runs the training script on the local machine. Assumes that Axon is running in the
 * wpilib/axon-hosted Docker container.
 *
 * Each script runs on its own thread. The thread records coarse status changes in
 * [scriptProgressMap]; [getTrainingProgress] reads the number of completed epochs from the
 * progress file while the script is running.
 */
class LocalTrainingScriptRunner : TrainingScriptRunner {

    private val nextScriptId = AtomicLong()

    // These maps are written by the training threads and read by whichever thread polls
    // getTrainingProgress, so they must tolerate concurrent access.
    private val scriptDataMap: MutableMap<Long, RunTrainingScriptConfiguration> =
        java.util.concurrent.ConcurrentHashMap()
    private val scriptProgressMap: MutableMap<Long, TrainingScriptProgress> =
        java.util.concurrent.ConcurrentHashMap()
    private val scriptThreadMap: MutableMap<Long, Thread> =
        java.util.concurrent.ConcurrentHashMap()

    /**
     * Writes the script to a temporary file and runs it with `python3.6` on a new thread.
     *
     * @param config The configuration of the script to run. The old model, the new model, and any
     * custom dataset must all be local, and the script must train for at least one epoch.
     * @return The id to pass to [getTrainingProgress].
     * @throws IllegalArgumentException If [config] violates one of the requirements above.
     */
    override fun startScript(config: RunTrainingScriptConfiguration): Long {
        require(config.oldModelName is FilePath.Local) {
            "Must start from a local model. Got: ${config.oldModelName}"
        }
        require(config.newModelName is FilePath.Local) {
            "Must export to a local model. Got: ${config.newModelName}"
        }
        require(config.epochs > 0) {
            "Must train for at least one epoch. Got ${config.epochs} epochs."
        }
        val dataset = config.dataset
        if (dataset is Dataset.Custom) {
            require(dataset.path is FilePath.Local) {
                "Custom datasets must be local. Got non-local dataset: ${config.dataset}"
            }
        }

        // createTempFile already creates the file on disk.
        val scriptFile = Files.createTempFile("", ".py").toFile()
        scriptFile.writeText(config.scriptContents)

        // Clear the progress file if there was a previous run
        val progressFile = progressFileFor(config)
        // parentFile is null when the progress path has no parent directory component.
        progressFile.parentFile?.mkdirs()
        progressFile.writeText("0.0")

        val scriptId = nextScriptId.getAndIncrement()
        scriptDataMap[scriptId] = config
        scriptProgressMap[scriptId] = TrainingScriptProgress.Creating

        scriptThreadMap[scriptId] = thread {
            scriptProgressMap[scriptId] = TrainingScriptProgress.Initializing

            try {
                runCommand(
                    listOf("python3.6", scriptFile.absolutePath),
                    emptyMap(),
                    null
                ).attempt().unsafeRunSync().fold(
                    {
                        LOGGER.warn(it) { "Training script failed." }
                        scriptProgressMap[scriptId] = TrainingScriptProgress.Error
                    },
                    { (exitCode, stdOut, stdErr) ->
                        LOGGER.info {
                            """
                            |Training script completed.
                            |Process exit code: $exitCode
                            |Process std out:
                            |$stdOut
                            |
                            |Process std err:
                            |$stdErr
                            |
                            """.trimMargin()
                        }

                        // Require a clean exit as well as the output model. Otherwise a model
                        // file left over from a previous run would make a failed run look
                        // like it completed.
                        val newModelFile = Paths.get(config.newModelName.path).toFile()
                        scriptProgressMap[scriptId] =
                            if (exitCode == 0 && newModelFile.exists()) {
                                TrainingScriptProgress.Completed
                            } else {
                                TrainingScriptProgress.Error
                            }
                    }
                )
            } finally {
                // The generated script is only needed while the process runs.
                scriptFile.delete()
            }
        }

        return scriptId
    }

    /**
     * Gets the progress of a script started with [startScript].
     *
     * @param scriptId The id returned by [startScript].
     * @return The current progress of the script.
     * @throws IllegalArgumentException If [scriptId] was not returned by [startScript].
     */
    override fun getTrainingProgress(scriptId: Long): TrainingScriptProgress {
        val config = requireNotNull(scriptDataMap[scriptId]) {
            "Unknown script id: $scriptId"
        }
        val scriptThread = requireNotNull(scriptThreadMap[scriptId]) {
            "Unknown script id: $scriptId"
        }

        // Check liveness BEFORE reading the status. If the order were reversed, the thread could
        // finish in between and a stale Initializing status would be reported as Error.
        val isRunning = scriptThread.isAlive
        val status = requireNotNull(scriptProgressMap[scriptId]) {
            "Unknown script id: $scriptId"
        }

        return if (isRunning) {
            // Training thread is still running. Terminal statuses are reported as-is; otherwise
            // try to read the progress file.
            when (status) {
                is TrainingScriptProgress.Completed -> TrainingScriptProgress.Completed
                is TrainingScriptProgress.Error -> TrainingScriptProgress.Error
                else -> readReportedProgress(config, status)
            }
        } else {
            // Training thread died or is not started yet. If it dies, either it finished and wrote
            // Completed to scriptProgressMap or exploded and didn't write Completed. If it is not
            // started yet, then the status will still be Creating.
            when (status) {
                is TrainingScriptProgress.Creating -> TrainingScriptProgress.Creating
                is TrainingScriptProgress.Completed -> TrainingScriptProgress.Completed
                else -> TrainingScriptProgress.Error
            }
        }
    }

    /** Locates the progress file for the model and dataset in [config]. */
    private fun progressFileFor(config: RunTrainingScriptConfiguration): File =
        File(
            createProgressFilePath(
                config.newModelName.filename,
                config.dataset.progressReportingName
            )
        )

    /**
     * Reads the progress file for a running script. Returns [fallback] until the file reports
     * progress greater than zero (startScript resets the file to "0.0"), InProgress with the
     * file's value divided by the configured epochs after that, and Error if the contents are
     * not a number.
     */
    private fun readReportedProgress(
        config: RunTrainingScriptConfiguration,
        fallback: TrainingScriptProgress
    ): TrainingScriptProgress {
        val progressFile = progressFileFor(config)
        if (!progressFile.exists()) {
            return fallback
        }

        return try {
            val epochsCompleted = progressFile.readText().toDouble()
            if (epochsCompleted > 0.0) {
                TrainingScriptProgress.InProgress(epochsCompleted / config.epochs)
            } else {
                fallback
            }
        } catch (ex: NumberFormatException) {
            TrainingScriptProgress.Error
        }
    }

    companion object {
        private val LOGGER = KotlinLogging.logger { }
    }
}

0 commit comments

Comments
 (0)