@@ -2,6 +2,7 @@ package edu.wpi.axon.aws
2
2
3
3
import edu.wpi.axon.dbdata.TrainingScriptProgress
4
4
import edu.wpi.axon.tfdata.Dataset
5
+ import edu.wpi.axon.util.FilePath
5
6
import java.lang.NumberFormatException
6
7
import java.util.Base64
7
8
import java.util.concurrent.atomic.AtomicLong
@@ -10,6 +11,7 @@ import org.apache.commons.lang3.RandomStringUtils
10
11
import org.koin.core.KoinComponent
11
12
import software.amazon.awssdk.services.ec2.Ec2Client
12
13
import software.amazon.awssdk.services.ec2.model.Ec2Exception
14
+ import software.amazon.awssdk.services.ec2.model.Filter
13
15
import software.amazon.awssdk.services.ec2.model.InstanceStateName
14
16
import software.amazon.awssdk.services.ec2.model.InstanceType
15
17
import software.amazon.awssdk.services.ec2.model.ShutdownBehavior
@@ -20,11 +22,12 @@ import software.amazon.awssdk.services.ec2.model.ShutdownBehavior
20
22
* this class will handle all of that. The script should just load and save the model from/to its
21
23
* current directory.
22
24
*
25
+ * @param bucketName The name of the S3 bucket to use.
23
26
* @param instanceType The type of the EC2 instance to run the training script on.
24
27
*/
25
28
class EC2TrainingScriptRunner (
26
29
bucketName : String ,
27
- private val instanceType : InstanceType // TODO: Move this to [startScript]
30
+ private val instanceType : InstanceType
28
31
) : TrainingScriptRunner, KoinComponent {
29
32
30
33
private val ec2 by lazy { Ec2Client .builder().build() }
@@ -35,30 +38,32 @@ class EC2TrainingScriptRunner(
35
38
private val scriptDataMap = mutableMapOf<Long , RunTrainingScriptConfiguration >()
36
39
37
40
override fun startScript (
38
- runTrainingScriptConfiguration : RunTrainingScriptConfiguration
41
+ config : RunTrainingScriptConfiguration
39
42
): Long {
40
- // Check for if the script uses the CLI to manage the model in S3. This class is supposed to
41
- // own working with S3.
42
- require(
43
- ! runTrainingScriptConfiguration.scriptContents.contains(" download_model" ) &&
44
- ! runTrainingScriptConfiguration.scriptContents.contains(" upload_model" )
45
- ) {
46
- """
47
- |Cannot start the script because it interfaces with AWS:
48
- |${runTrainingScriptConfiguration.scriptContents}
49
- |
50
- """ .trimMargin()
43
+ require(config.oldModelName is FilePath .S3 ) {
44
+ " Must start from a model in S3. Got: ${config.oldModelName} "
45
+ }
46
+ require(config.newModelName is FilePath .S3 ) {
47
+ " Must export to a model in S3. Got: ${config.newModelName} "
48
+ }
49
+ require(config.epochs > 0 ) {
50
+ " Must train for at least one epoch. Got ${config.epochs} epochs."
51
+ }
52
+ when (config.dataset) {
53
+ is Dataset .Custom -> require(config.dataset.path is FilePath .S3 ) {
54
+ " Custom datasets must be in S3. Got non-local dataset: ${config.dataset} "
55
+ }
51
56
}
52
57
53
58
// The file name for the generated script
54
59
val scriptFileName = " ${RandomStringUtils .randomAlphanumeric(20 )} .py"
55
60
56
- val newModelName = runTrainingScriptConfiguration .newModelName
57
- val datasetName = runTrainingScriptConfiguration .dataset.nameForS3ProgressReporting
61
+ val newModelName = config .newModelName.filename
62
+ val datasetName = config .dataset.progressReportingName
58
63
59
64
s3Manager.uploadTrainingScript(
60
65
scriptFileName,
61
- runTrainingScriptConfiguration .scriptContents
66
+ config .scriptContents
62
67
)
63
68
64
69
// Reset the training progress so the script doesn't start in the completed state
@@ -69,10 +74,10 @@ class EC2TrainingScriptRunner(
69
74
70
75
// We need to download custom datasets from S3. Example datasets will be downloaded
71
76
// by the script using Keras.
72
- val downloadDatasetString = when (runTrainingScriptConfiguration .dataset) {
77
+ val downloadDatasetString = when (config .dataset) {
73
78
is Dataset .ExampleDataset -> " "
74
79
is Dataset .Custom ->
75
- """ axon download-dataset "${runTrainingScriptConfiguration .dataset.pathInS3 } """"
80
+ """ axon download-dataset "${config .dataset.path.path } """"
76
81
}
77
82
78
83
val scriptForEC2 = """
@@ -91,7 +96,7 @@ class EC2TrainingScriptRunner(
91
96
|pip3 install https://github.com/wpilibsuite/axon-cli/releases/download/v0.1.11/axon-0.1.11-py2.py3-none-any.whl
92
97
|axon create-heartbeat "$newModelName " "$datasetName "
93
98
|axon update-training-progress "$newModelName " "$datasetName " "initializing"
94
- |axon download-untrained-model "${runTrainingScriptConfiguration .oldModelName} "
99
+ |axon download-untrained-model "${config .oldModelName.path } "
95
100
|$downloadDatasetString
96
101
|axon download-training-script "$scriptFileName "
97
102
|docker run -v ${' $' } (eval "pwd"):/home wpilib/axon-ci:latest "/usr/bin/python3.6 /home/$scriptFileName "
@@ -122,7 +127,7 @@ class EC2TrainingScriptRunner(
122
127
123
128
val scriptId = nextScriptId.getAndIncrement()
124
129
instanceIds[scriptId] = runInstancesResponse.instances().first().instanceId()
125
- scriptDataMap[scriptId] = runTrainingScriptConfiguration
130
+ scriptDataMap[scriptId] = config
126
131
return scriptId
127
132
}
128
133
@@ -132,14 +137,24 @@ class EC2TrainingScriptRunner(
132
137
require(scriptId in scriptDataMap.keys)
133
138
134
139
val runTrainingScriptConfiguration = scriptDataMap[scriptId]!!
135
- val newModelName = runTrainingScriptConfiguration.newModelName
136
- val datasetName = runTrainingScriptConfiguration.dataset.nameForS3ProgressReporting
140
+ val newModelName = runTrainingScriptConfiguration.newModelName.filename
141
+ val datasetName = runTrainingScriptConfiguration.dataset.progressReportingName
137
142
138
143
val status = try {
139
144
ec2.describeInstanceStatus {
140
145
it.instanceIds(instanceIds[scriptId]!! )
146
+ .includeAllInstances(true )
147
+ .filters(
148
+ Filter .builder().name(" instance-state-name" ).values(
149
+ " pending" ,
150
+ " running" ,
151
+ " shutting-down" ,
152
+ " stopping"
153
+ ).build()
154
+ )
141
155
}.instanceStatuses().firstOrNull()?.instanceState()?.name()
142
156
} catch (ex: Ec2Exception ) {
157
+ LOGGER .warn(ex) { " Failed to get instance status." }
143
158
null
144
159
}
145
160
@@ -166,13 +181,29 @@ class EC2TrainingScriptRunner(
166
181
status : InstanceStateName ? ,
167
182
epochs : Int
168
183
): TrainingScriptProgress {
184
+ LOGGER .debug {
185
+ """
186
+ |Heartbeat: $heartbeat
187
+ |Progress: $progress
188
+ |Instance status: $status
189
+ """ .trimMargin()
190
+ }
191
+
169
192
val progressAssumingEverythingIsFine = computeProgressAssumingEverythingIsFine(
170
193
heartbeat,
171
194
progress,
172
195
status,
173
196
epochs
174
197
)
175
198
199
+ if ((status == InstanceStateName .SHUTTING_DOWN ||
200
+ status == InstanceStateName .TERMINATED ||
201
+ status == InstanceStateName .STOPPING ) &&
202
+ (heartbeat != " 0" || progress != " completed" )
203
+ ) {
204
+ return TrainingScriptProgress .Error
205
+ }
206
+
176
207
return when (heartbeat) {
177
208
" 0" -> when (progress) {
178
209
" not started" , " completed" -> progressAssumingEverythingIsFine
@@ -185,8 +216,7 @@ class EC2TrainingScriptRunner(
185
216
186
217
else -> when (status) {
187
218
InstanceStateName .SHUTTING_DOWN , InstanceStateName .TERMINATED ,
188
- InstanceStateName .STOPPING , InstanceStateName .STOPPED , null ->
189
- TrainingScriptProgress .Error
219
+ InstanceStateName .STOPPING -> TrainingScriptProgress .Error
190
220
191
221
else -> progressAssumingEverythingIsFine
192
222
}
@@ -205,7 +235,7 @@ class EC2TrainingScriptRunner(
205
235
" not started" -> when (status) {
206
236
InstanceStateName .PENDING -> TrainingScriptProgress .Creating
207
237
InstanceStateName .RUNNING -> TrainingScriptProgress .Initializing
208
- else -> TrainingScriptProgress .NotStarted
238
+ else -> TrainingScriptProgress .Creating
209
239
}
210
240
211
241
" initializing" -> TrainingScriptProgress .Initializing
0 commit comments