@@ -4,34 +4,26 @@ import edu.wpi.axon.dbdata.TrainingScriptProgress
4
4
import edu.wpi.axon.tfdata.Dataset
5
5
import edu.wpi.axon.util.FilePath
6
6
import java.lang.NumberFormatException
7
- import java.util.Base64
8
7
import mu.KotlinLogging
9
8
import org.apache.commons.lang3.RandomStringUtils
10
9
import org.koin.core.KoinComponent
11
- import software.amazon.awssdk.services.ec2.Ec2Client
12
- import software.amazon.awssdk.services.ec2.model.Ec2Exception
13
- import software.amazon.awssdk.services.ec2.model.Filter
14
10
import software.amazon.awssdk.services.ec2.model.InstanceStateName
15
11
import software.amazon.awssdk.services.ec2.model.InstanceType
16
- import software.amazon.awssdk.services.ec2.model.ShutdownBehavior
17
12
18
13
/* *
19
14
* A [TrainingScriptRunner] that runs the training script on EC2 and hosts datasets and models on
20
15
* S3. This implementation requires that the script does not try to manage models with S3 itself:
21
16
* this class will handle all of that. The script should just load and save the model from/to its
22
17
* current directory.
23
18
*
24
- * @param bucketName The name of the S3 bucket to use.
25
19
* @param instanceType The type of the EC2 instance to run the training script on.
26
20
*/
27
21
class EC2TrainingScriptRunner (
28
- bucketName : String ,
29
- private val instanceType : InstanceType
22
+ private val instanceType : InstanceType ,
23
+ private val ec2Manager : EC2Manager ,
24
+ private val s3Manager : S3Manager
30
25
) : TrainingScriptRunner, KoinComponent {
31
26
32
- private val ec2 by lazy { Ec2Client .builder().build() }
33
- private val s3Manager = S3Manager (bucketName)
34
-
35
27
private val instanceIds = mutableMapOf<Int , String >()
36
28
private val scriptDataMap = mutableMapOf<Int , RunTrainingScriptConfiguration >()
37
29
@@ -54,15 +46,13 @@ class EC2TrainingScriptRunner(
54
46
}
55
47
56
48
// The file name for the generated script
49
+ @Suppress(" MagicNumber" )
57
50
val scriptFileName = " ${RandomStringUtils .randomAlphanumeric(20 )} .py"
58
51
59
52
val newModelName = config.newModelName.filename
60
53
val datasetName = config.dataset.progressReportingName
61
54
62
- s3Manager.uploadTrainingScript(
63
- scriptFileName,
64
- config.scriptContents
65
- )
55
+ s3Manager.uploadTrainingScript(scriptFileName, config.scriptContents)
66
56
67
57
// Reset the training progress so the script doesn't start in the completed state
68
58
s3Manager.setTrainingProgress(newModelName, datasetName, " not started" )
@@ -112,18 +102,7 @@ class EC2TrainingScriptRunner(
112
102
""" .trimMargin()
113
103
}
114
104
115
- val runInstancesResponse = ec2.runInstances {
116
- it.imageId(" ami-04b9e92b5572fa0d1" )
117
- .instanceType(instanceType)
118
- .maxCount(1 )
119
- .minCount(1 )
120
- .userData(scriptForEC2.toBase64())
121
- .securityGroups(" axon-autogenerated-ec2-sg" )
122
- .instanceInitiatedShutdownBehavior(ShutdownBehavior .TERMINATE )
123
- .iamInstanceProfile { it.name(" axon-autogenerated-ec2-instance-profile" ) }
124
- }
125
-
126
- instanceIds[config.id] = runInstancesResponse.instances().first().instanceId()
105
+ instanceIds[config.id] = ec2Manager.startTrainingInstance(scriptForEC2, instanceType)
127
106
scriptDataMap[config.id] = config
128
107
}
129
108
@@ -136,23 +115,7 @@ class EC2TrainingScriptRunner(
136
115
val newModelName = runTrainingScriptConfiguration.newModelName.filename
137
116
val datasetName = runTrainingScriptConfiguration.dataset.progressReportingName
138
117
139
- val status = try {
140
- ec2.describeInstanceStatus {
141
- it.instanceIds(instanceIds[jobId]!! )
142
- .includeAllInstances(true )
143
- .filters(
144
- Filter .builder().name(" instance-state-name" ).values(
145
- " pending" ,
146
- " running" ,
147
- " shutting-down" ,
148
- " stopping"
149
- ).build()
150
- )
151
- }.instanceStatuses().firstOrNull()?.instanceState()?.name()
152
- } catch (ex: Ec2Exception ) {
153
- LOGGER .debug(ex) { " Failed to get instance status." }
154
- null
155
- }
118
+ val status = ec2Manager.getInstanceState(instanceIds[jobId]!! )
156
119
157
120
val heartbeat = s3Manager.getHeartbeat(newModelName, datasetName)
158
121
val progress = s3Manager.getTrainingProgress(newModelName, datasetName)
@@ -167,15 +130,9 @@ class EC2TrainingScriptRunner(
167
130
168
131
override fun cancelScript (jobId : Int ) {
169
132
require(jobId in instanceIds.keys)
170
-
171
- ec2.terminateInstances {
172
- it.instanceIds(instanceIds[jobId]!! )
173
- }
133
+ ec2Manager.terminateInstance(instanceIds[jobId]!! )
174
134
}
175
135
176
- private fun String.toBase64 () =
177
- Base64 .getEncoder().encodeToString(byteInputStream().readAllBytes())
178
-
179
136
companion object {
180
137
private val LOGGER = KotlinLogging .logger { }
181
138
0 commit comments