Skip to content
This repository was archived by the owner on Sep 26, 2020. It is now read-only.

Commit cc5ef38

Browse files
Add more accurate progress reporting (#136)
* Add more descriptive progress reporting. Still need to handle errors in the UI
* Simplify EC2TrainingScriptRunner implementation
* Simplify EC2TrainingScriptRunner implementation
* Improve progress detection and add testing
* Catch Ec2Exception if the SDK can't find the instance we just created
* Add UI support for new progress bar states
* Add a creating progress state
* Update the job progress to be creating immediately after running. Add waitForChange. Add JobRunner tests.
* Spotless
* Formatting

Co-authored-by: Austin Shalit <austinshalit@gmail.com>
1 parent 8fbba13 commit cc5ef38

File tree

14 files changed

+567
-202
lines changed

14 files changed

+567
-202
lines changed

aws/src/main/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunner.kt

+171-132
Large diffs are not rendered by default.

aws/src/main/kotlin/edu/wpi/axon/aws/S3Manager.kt

+67-6
Original file line numberDiff line numberDiff line change
@@ -105,17 +105,72 @@ class S3Manager(
105105
}.readAllBytes().decodeToString()
106106

107107
/**
108-
* Resets the latest training progress data.
108+
* Sets the training progress data.
109109
*
110110
* @param modelName The filename of the model being trained.
111111
* @param datasetName The filename of the dataset being trained on.
112+
* @param data The data to write to the progress file.
112113
*/
113-
fun resetTrainingProgress(modelName: String, datasetName: String) {
114-
s3.deleteObject {
115-
it.bucket(bucketName).key(createTrainingProgressFilePath(modelName, datasetName))
116-
}
114+
fun setTrainingProgress(modelName: String, datasetName: String, data: String) {
115+
s3.putObject(
116+
PutObjectRequest.builder().bucket(bucketName).key(
117+
createTrainingProgressFilePath(
118+
modelName,
119+
datasetName
120+
)
121+
).build(),
122+
RequestBody.fromString(data)
123+
)
124+
}
125+
126+
/**
127+
* Creates a heartbeat that Axon uses to check if the training script is running properly.
128+
*
129+
* @param modelName The filename of the model being trained.
130+
* @param datasetName The filename of the dataset being trained on.
131+
*/
132+
fun createHeartbeat(modelName: String, datasetName: String) {
133+
s3.putObject(
134+
PutObjectRequest.builder().bucket(bucketName).key(
135+
createHeartbeatFilePath(
136+
modelName,
137+
datasetName
138+
)
139+
).build(),
140+
RequestBody.fromString("1")
141+
)
117142
}
118143

144+
/**
145+
* Clears the heartbeat that Axon uses to check if the training script is running properly by overwriting the heartbeat file with "0" (the file itself is not deleted).
146+
*
147+
* @param modelName The filename of the model being trained.
148+
* @param datasetName The filename of the dataset being trained on.
149+
*/
150+
fun removeHeartbeat(modelName: String, datasetName: String) {
151+
s3.putObject(
152+
PutObjectRequest.builder().bucket(bucketName).key(
153+
createHeartbeatFilePath(
154+
modelName,
155+
datasetName
156+
)
157+
).build(),
158+
RequestBody.fromString("0")
159+
)
160+
}
161+
162+
/**
163+
* Gets the latest heartbeat.
164+
*
165+
* @param modelName The filename of the model being trained.
166+
* @param datasetName The filename of the dataset being trained on.
167+
* @return The contents of the heartbeat file.
168+
*/
169+
@UseExperimental(ExperimentalStdlibApi::class)
170+
fun getHeartbeat(modelName: String, datasetName: String) = s3.getObject {
171+
it.bucket(bucketName).key(createHeartbeatFilePath(modelName, datasetName))
172+
}.readAllBytes().decodeToString()
173+
119174
/**
120175
* Downloads the preferences file to a local file. Throws an exception if there is no
121176
* preferences file in S3.
@@ -189,8 +244,14 @@ class S3Manager(
189244
it.bucket(bucketName).prefix(prefix)
190245
}.contents().map { it.key().substring(prefix.length) }
191246

247+
private fun createTrainingProgressPrefix(modelName: String, datasetName: String) =
248+
"axon-training-progress/$modelName/$datasetName"
249+
192250
private fun createTrainingProgressFilePath(modelName: String, datasetName: String) =
193-
"axon-training-progress/$modelName/$datasetName/progress.txt"
251+
"${createTrainingProgressPrefix(modelName, datasetName)}/progress.txt"
252+
253+
private fun createHeartbeatFilePath(modelName: String, datasetName: String) =
254+
"${createTrainingProgressPrefix(modelName, datasetName)}/heartbeat.txt"
194255

195256
companion object {
196257
private const val preferencesFilename = "axon-preferences.json"
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package edu.wpi.axon.aws
22

3-
import arrow.fx.IO
43
import edu.wpi.axon.dbdata.TrainingScriptProgress
54

65
interface TrainingScriptRunner {
@@ -11,13 +10,13 @@ interface TrainingScriptRunner {
1110
* @param runTrainingScriptConfiguration The data needed to start the script.
1211
* @return The script id used to query about the script during and after training.
1312
*/
14-
fun startScript(runTrainingScriptConfiguration: RunTrainingScriptConfiguration): IO<Long>
13+
fun startScript(runTrainingScriptConfiguration: RunTrainingScriptConfiguration): Long
1514

1615
/**
1716
* Queries for the current progress state of the script.
1817
*
1918
* @param scriptId The id of the script, from [startScript].
2019
* @return The current progress state of the script.
2120
*/
22-
fun getTrainingProgress(scriptId: Long): IO<TrainingScriptProgress>
21+
fun getTrainingProgress(scriptId: Long): TrainingScriptProgress
2322
}

aws/src/test/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunnerTest.kt

+104-6
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
package edu.wpi.axon.aws
22

3+
import edu.wpi.axon.dbdata.TrainingScriptProgress
34
import edu.wpi.axon.tfdata.Dataset
5+
import kotlin.test.assertEquals
46
import org.junit.jupiter.api.Disabled
57
import org.junit.jupiter.api.Test
8+
import org.junit.jupiter.params.ParameterizedTest
9+
import org.junit.jupiter.params.provider.Arguments
10+
import org.junit.jupiter.params.provider.MethodSource
11+
import software.amazon.awssdk.services.ec2.model.InstanceStateName
612
import software.amazon.awssdk.services.ec2.model.InstanceType
713

814
internal class EC2TrainingScriptRunnerTest {
915

16+
private val runner = EC2TrainingScriptRunner(
17+
"axon-autogenerated-5bl5pyn1h8g73kxak0xsnacql02e9i",
18+
InstanceType.T2_MICRO
19+
)
20+
1021
@Test
1122
@Disabled("Needs EC2 supervision.")
1223
fun `test running mnist training script`() {
13-
val runner = EC2TrainingScriptRunner(
14-
"axon-autogenerated-5bl5pyn1h8g73kxak0xsnacql02e9i",
15-
InstanceType.T2_MICRO
16-
)
17-
1824
runner.startScript(
1925
RunTrainingScriptConfiguration(
2026
"custom_fashion_mnist.h5",
@@ -92,6 +98,98 @@ internal class EC2TrainingScriptRunnerTest {
9298
""".trimIndent(),
9399
1
94100
)
95-
).unsafeRunSync()
101+
)
102+
}
103+
104+
@ParameterizedTest
105+
@MethodSource("progressTestSource")
106+
fun `test progress`(
107+
heartbeat: String,
108+
progress: String,
109+
status: InstanceStateName?,
110+
epochs: Int,
111+
expected: TrainingScriptProgress
112+
) {
113+
assertEquals(
114+
expected,
115+
EC2TrainingScriptRunner.computeTrainingScriptProgress(
116+
heartbeat,
117+
progress,
118+
status,
119+
epochs
120+
)
121+
)
122+
}
123+
124+
companion object {
125+
126+
@JvmStatic
127+
@Suppress("unused")
128+
fun progressTestSource() = listOf(
129+
Arguments.of(
130+
"0", "not started", InstanceStateName.PENDING, 1,
131+
TrainingScriptProgress.Creating
132+
),
133+
Arguments.of(
134+
"0", "not started", InstanceStateName.RUNNING, 1,
135+
TrainingScriptProgress.Initializing
136+
),
137+
Arguments.of("0", "not started", null, 1, TrainingScriptProgress.NotStarted),
138+
Arguments.of(
139+
"1",
140+
"not started",
141+
InstanceStateName.PENDING,
142+
1,
143+
TrainingScriptProgress.Error
144+
),
145+
Arguments.of(
146+
"1",
147+
"not started",
148+
InstanceStateName.RUNNING,
149+
1,
150+
TrainingScriptProgress.Error
151+
),
152+
Arguments.of("1", "not started", null, 1, TrainingScriptProgress.Error),
153+
Arguments.of(
154+
"0",
155+
"completed",
156+
InstanceStateName.STOPPING,
157+
1,
158+
TrainingScriptProgress.Completed
159+
),
160+
Arguments.of("0", "completed", null, 1, TrainingScriptProgress.Completed),
161+
Arguments.of("1", "completed", null, 1, TrainingScriptProgress.Error),
162+
Arguments.of(
163+
"1",
164+
"1.0",
165+
InstanceStateName.RUNNING,
166+
1,
167+
TrainingScriptProgress.InProgress(1.0)
168+
),
169+
Arguments.of("1", "1.0", InstanceStateName.STOPPING, 1, TrainingScriptProgress.Error),
170+
Arguments.of("1", "1.0", InstanceStateName.TERMINATED, 1, TrainingScriptProgress.Error),
171+
Arguments.of(
172+
"0",
173+
"initializing",
174+
InstanceStateName.RUNNING,
175+
1,
176+
TrainingScriptProgress.Error
177+
),
178+
Arguments.of(
179+
"1",
180+
"initializing",
181+
InstanceStateName.RUNNING,
182+
1,
183+
TrainingScriptProgress.Initializing
184+
),
185+
Arguments.of(
186+
"2",
187+
"initializing",
188+
InstanceStateName.RUNNING,
189+
1,
190+
TrainingScriptProgress.Error
191+
),
192+
Arguments.of("1", "foo", InstanceStateName.RUNNING, 1, TrainingScriptProgress.Error)
193+
)
96194
}
97195
}

db-data/src/main/kotlin/edu/wpi/axon/dbdata/TrainingScriptProgress.kt

+18
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@ sealed class TrainingScriptProgress : Comparable<TrainingScriptProgress> {
1616
@Serializable
1717
object NotStarted : TrainingScriptProgress()
1818

19+
/**
20+
* The machine that is going to run the training script is being provisioned.
21+
*/
22+
@Serializable
23+
object Creating : TrainingScriptProgress()
24+
25+
/**
26+
* The machine that is going to run the training script is initializing the environment.
27+
*/
28+
@Serializable
29+
object Initializing : TrainingScriptProgress()
30+
1931
/**
2032
* The training is in progress.
2133
*
@@ -30,6 +42,12 @@ sealed class TrainingScriptProgress : Comparable<TrainingScriptProgress> {
3042
@Serializable
3143
object Completed : TrainingScriptProgress()
3244

45+
/**
46+
* The training script encountered an error.
47+
*/
48+
@Serializable
49+
object Error : TrainingScriptProgress()
50+
3351
fun serialize(): String = Json(
3452
JsonConfiguration.Stable
3553
).stringify(serializer(), this)

ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/JobRunner.kt

+46-23
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package edu.wpi.axon.ui
22

3+
import arrow.core.Either
34
import arrow.core.None
45
import arrow.fx.IO
56
import arrow.fx.extensions.fx
@@ -12,6 +13,7 @@ import edu.wpi.axon.tfdata.Model
1213
import edu.wpi.axon.training.TrainGeneralModelScriptGenerator
1314
import edu.wpi.axon.training.TrainSequentialModelScriptGenerator
1415
import edu.wpi.axon.training.TrainState
16+
import kotlinx.coroutines.delay
1517
import org.koin.core.KoinComponent
1618
import org.koin.core.inject
1719
import org.koin.core.qualifier.named
@@ -27,11 +29,13 @@ class JobRunner : KoinComponent {
2729
* @param job The [Job] to run.
2830
* @return The script id of the script that was started.
2931
*/
30-
fun startJob(job: Job): Long = IO.fx {
31-
val trainModelScriptGenerator = when (val model = job.userModel) {
32-
is Model.Sequential -> TrainSequentialModelScriptGenerator(toTrainState(job, model))
33-
is Model.General -> TrainGeneralModelScriptGenerator(toTrainState(job, model))
34-
}
32+
fun startJob(job: Job): IO<Long> = IO.fx {
33+
val trainModelScriptGenerator = IO {
34+
when (val model = job.userModel) {
35+
is Model.Sequential -> TrainSequentialModelScriptGenerator(toTrainState(job, model))
36+
is Model.General -> TrainGeneralModelScriptGenerator(toTrainState(job, model))
37+
}
38+
}.bind()
3539

3640
val script = trainModelScriptGenerator.generateScript().fold(
3741
{
@@ -55,29 +59,48 @@ class JobRunner : KoinComponent {
5559
scriptContents = script,
5660
epochs = job.userEpochs
5761
)
58-
).bind()
59-
}.unsafeRunSync()
60-
61-
fun getProgress(id: Long) = scriptRunner.getTrainingProgress(id)
62+
)
63+
}
6264

63-
fun waitForCompleted(id: Long, progressUpdate: (TrainingScriptProgress) -> Unit) {
64-
while (true) {
65-
val shouldBreak = getProgress(id).attempt().unsafeRunSync().fold({
66-
// TODO: More intelligent progress reporting than this. We shouldn't have to catch an exception each
67-
// time
68-
false
69-
}, {
65+
/**
66+
* Waits until the [TrainingScriptProgress] state is either completed or error.
67+
*
68+
* @param id The script id.
69+
* @param progressUpdate A callback that is given the current [TrainingScriptProgress] state
70+
* every time it is polled.
71+
* @return An [IO] for continuation.
72+
*/
73+
fun waitForCompleted(id: Long, progressUpdate: (TrainingScriptProgress) -> Unit): IO<Unit> =
74+
IO.tailRecM(scriptRunner.getTrainingProgress(id)) {
75+
IO {
7076
progressUpdate(it)
71-
it == TrainingScriptProgress.Completed
72-
})
73-
74-
if (shouldBreak) {
75-
break
77+
if (it == TrainingScriptProgress.Completed || it == TrainingScriptProgress.Error) {
78+
Either.Right(Unit)
79+
} else {
80+
delay(5000)
81+
Either.Left(scriptRunner.getTrainingProgress(id))
82+
}
7683
}
84+
}
7785

78-
Thread.sleep(5000)
86+
/**
87+
* Waits until the [TrainingScriptProgress] state changes.
88+
*
89+
* @param id The script id.
90+
* @return An [IO] for continuation.
91+
*/
92+
fun waitForChange(id: Long): IO<Unit> =
93+
IO.tailRecM(scriptRunner.getTrainingProgress(id)) {
94+
IO {
95+
delay(5000)
96+
val newState = scriptRunner.getTrainingProgress(id)
97+
if (it != newState) {
98+
Either.Right(Unit)
99+
} else {
100+
Either.Left(newState)
101+
}
102+
}
79103
}
80-
}
81104

82105
private fun <T : Model> toTrainState(
83106
job: Job,

ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/WebAppListener.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ class WebAppListener : ServletContextListener, KoinComponent {
6767
val localModelPath =
6868
Paths.get("/home/salmon/Documents/Axon/training/src/test/resources/edu/wpi/axon/training/$modelName")
6969
.toString()
70-
val layers = ModelLoaderFactory().createModeLoader(localModelPath).load(File(localModelPath))
70+
val layers =
71+
ModelLoaderFactory().createModeLoader(localModelPath).load(File(localModelPath))
7172
val model = layers.attempt().unsafeRunSync()
7273
check(model is Either.Right)
7374
return model.b to localModelPath

0 commit comments

Comments
 (0)