Skip to content
This repository was archived by the owner on Sep 26, 2020. It is now read-only.

Commit 012b2a3

Browse files
authored
Improve Testing and Code Quality (#144)
* Remove RDSJobDBConfigurator
* Refactor our EC2Manager. Improve EC2TrainingScriptRunnerTest with better testing.
* Fix missing tag in LocalTrainingScriptRunnerIntegTest
* Cleanup
* Remove EC2TrainingScriptRunnerIntegTest
* Remove dependency on aws sdk v1
* Separate out Linux_Docker job
1 parent 178f601 commit 012b2a3

File tree

32 files changed

+878
-663
lines changed

32 files changed

+878
-663
lines changed

aws/aws.gradle.kts

-5
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,6 @@ dependencies {
2828
name = "aws-sdk-java",
2929
version = "2.10.12"
3030
)
31-
implementation(
32-
group = "com.amazonaws",
33-
name = "aws-java-sdk",
34-
version = "1.11.674"
35-
)
3631

3732
implementation(
3833
group = "com.beust",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package edu.wpi.axon.aws
2+
3+
import java.util.Base64
4+
import mu.KotlinLogging
5+
import software.amazon.awssdk.services.ec2.Ec2Client
6+
import software.amazon.awssdk.services.ec2.model.Ec2Exception
7+
import software.amazon.awssdk.services.ec2.model.Filter
8+
import software.amazon.awssdk.services.ec2.model.InstanceStateName
9+
import software.amazon.awssdk.services.ec2.model.InstanceType
10+
import software.amazon.awssdk.services.ec2.model.ShutdownBehavior
11+
12+
/**
 * Manages interacting with EC2.
 *
 * Wraps a single [Ec2Client] and exposes the three operations used for training jobs: starting an
 * instance, querying its state, and terminating it.
 */
class EC2Manager {

    // Built eagerly with the SDK's default configuration when the manager is constructed.
    private val ec2 = Ec2Client.builder().build()

    /**
     * Starts a new instance for running a training script.
     *
     * Exactly one instance is requested (min and max count are both 1), and it is configured to
     * terminate when it shuts itself down.
     *
     * @param scriptData The data for the EC2 instance to run when it boots. This should not
     * contain the entire training script, as that would be too much data. Instead, this script
     * should use Axon's CLI to download the training script from S3 at runtime. It is Base64
     * encoded before being sent as the instance's user data.
     * @param instanceType The type of the instance to start.
     * @return The ID of the instance that was started.
     * @throws NoSuchElementException if the response contains no instances.
     */
    fun startTrainingInstance(scriptData: String, instanceType: InstanceType): String {
        val runInstancesResponse = ec2.runInstances { request ->
            request.imageId(TRAINING_IMAGE_ID)
                .instanceType(instanceType)
                .maxCount(1)
                .minCount(1)
                .userData(scriptData.toBase64())
                .securityGroups(SECURITY_GROUP_NAME)
                .instanceInitiatedShutdownBehavior(ShutdownBehavior.TERMINATE)
                .iamInstanceProfile { profile -> profile.name(INSTANCE_PROFILE_NAME) }
        }

        return runInstancesResponse.instances().first().instanceId()
    }

    /**
     * Gets the state of the instance, provided it is in one of the states this method asks for:
     * `pending`, `running`, `shutting-down` or `stopping`.
     *
     * The request is filtered to those four states, so an instance in any other state (for
     * example one that is already stopped or terminated) produces no match and `null` is returned.
     *
     * @param instanceId The ID of the instance.
     * @return The state of the instance, or `null` if no status matched the filter or the request
     * failed with an [Ec2Exception] (which is logged at debug level and not rethrown).
     */
    fun getInstanceState(instanceId: String): InstanceStateName? {
        return try {
            ec2.describeInstanceStatus { request ->
                request.instanceIds(instanceId)
                    // Without this flag only running instances are reported.
                    .includeAllInstances(true)
                    .filters(
                        Filter.builder().name("instance-state-name").values(
                            "pending",
                            "running",
                            "shutting-down",
                            "stopping"
                        ).build()
                    )
            }.instanceStatuses().firstOrNull()?.instanceState()?.name()
        } catch (ex: Ec2Exception) {
            // Best-effort lookup: callers treat a failed query the same as "no state".
            LOGGER.debug(ex) { "Failed to get instance status." }
            null
        }
    }

    /**
     * Terminates the instance.
     *
     * @param instanceId The ID of the instance.
     */
    fun terminateInstance(instanceId: String) {
        ec2.terminateInstances { request ->
            request.instanceIds(instanceId)
        }
    }

    /**
     * Encodes the UTF-8 bytes of this string as Base64.
     *
     * Uses [toByteArray] directly instead of draining a `ByteArrayInputStream` with
     * `readAllBytes()`, which only exists on Java 9+ and produces the same bytes.
     */
    private fun String.toBase64(): String =
        Base64.getEncoder().encodeToString(toByteArray())

    companion object {
        private val LOGGER = KotlinLogging.logger { }

        // NOTE(review): AMI IDs are region-specific, so this presumably only resolves in the
        // region the project was developed against - confirm before running elsewhere.
        private const val TRAINING_IMAGE_ID = "ami-04b9e92b5572fa0d1"

        // Names of the security group and instance profile attached to every training instance.
        // NOTE(review): the "autogenerated" prefix suggests these are created by other tooling -
        // verify they exist in the target account.
        private const val SECURITY_GROUP_NAME = "axon-autogenerated-ec2-sg"
        private const val INSTANCE_PROFILE_NAME = "axon-autogenerated-ec2-instance-profile"
    }
}

aws/src/main/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunner.kt

+8-51
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,26 @@ import edu.wpi.axon.dbdata.TrainingScriptProgress
44
import edu.wpi.axon.tfdata.Dataset
55
import edu.wpi.axon.util.FilePath
66
import java.lang.NumberFormatException
7-
import java.util.Base64
87
import mu.KotlinLogging
98
import org.apache.commons.lang3.RandomStringUtils
109
import org.koin.core.KoinComponent
11-
import software.amazon.awssdk.services.ec2.Ec2Client
12-
import software.amazon.awssdk.services.ec2.model.Ec2Exception
13-
import software.amazon.awssdk.services.ec2.model.Filter
1410
import software.amazon.awssdk.services.ec2.model.InstanceStateName
1511
import software.amazon.awssdk.services.ec2.model.InstanceType
16-
import software.amazon.awssdk.services.ec2.model.ShutdownBehavior
1712

1813
/**
1914
* A [TrainingScriptRunner] that runs the training script on EC2 and hosts datasets and models on
2015
* S3. This implementation requires that the script does not try to manage models with S3 itself:
2116
* this class will handle all of that. The script should just load and save the model from/to its
2217
* current directory.
2318
*
24-
* @param bucketName The name of the S3 bucket to use.
2519
* @param instanceType The type of the EC2 instance to run the training script on.
2620
*/
2721
class EC2TrainingScriptRunner(
28-
bucketName: String,
29-
private val instanceType: InstanceType
22+
private val instanceType: InstanceType,
23+
private val ec2Manager: EC2Manager,
24+
private val s3Manager: S3Manager
3025
) : TrainingScriptRunner, KoinComponent {
3126

32-
private val ec2 by lazy { Ec2Client.builder().build() }
33-
private val s3Manager = S3Manager(bucketName)
34-
3527
private val instanceIds = mutableMapOf<Int, String>()
3628
private val scriptDataMap = mutableMapOf<Int, RunTrainingScriptConfiguration>()
3729

@@ -54,15 +46,13 @@ class EC2TrainingScriptRunner(
5446
}
5547

5648
// The file name for the generated script
49+
@Suppress("MagicNumber")
5750
val scriptFileName = "${RandomStringUtils.randomAlphanumeric(20)}.py"
5851

5952
val newModelName = config.newModelName.filename
6053
val datasetName = config.dataset.progressReportingName
6154

62-
s3Manager.uploadTrainingScript(
63-
scriptFileName,
64-
config.scriptContents
65-
)
55+
s3Manager.uploadTrainingScript(scriptFileName, config.scriptContents)
6656

6757
// Reset the training progress so the script doesn't start in the completed state
6858
s3Manager.setTrainingProgress(newModelName, datasetName, "not started")
@@ -112,18 +102,7 @@ class EC2TrainingScriptRunner(
112102
""".trimMargin()
113103
}
114104

115-
val runInstancesResponse = ec2.runInstances {
116-
it.imageId("ami-04b9e92b5572fa0d1")
117-
.instanceType(instanceType)
118-
.maxCount(1)
119-
.minCount(1)
120-
.userData(scriptForEC2.toBase64())
121-
.securityGroups("axon-autogenerated-ec2-sg")
122-
.instanceInitiatedShutdownBehavior(ShutdownBehavior.TERMINATE)
123-
.iamInstanceProfile { it.name("axon-autogenerated-ec2-instance-profile") }
124-
}
125-
126-
instanceIds[config.id] = runInstancesResponse.instances().first().instanceId()
105+
instanceIds[config.id] = ec2Manager.startTrainingInstance(scriptForEC2, instanceType)
127106
scriptDataMap[config.id] = config
128107
}
129108

@@ -136,23 +115,7 @@ class EC2TrainingScriptRunner(
136115
val newModelName = runTrainingScriptConfiguration.newModelName.filename
137116
val datasetName = runTrainingScriptConfiguration.dataset.progressReportingName
138117

139-
val status = try {
140-
ec2.describeInstanceStatus {
141-
it.instanceIds(instanceIds[jobId]!!)
142-
.includeAllInstances(true)
143-
.filters(
144-
Filter.builder().name("instance-state-name").values(
145-
"pending",
146-
"running",
147-
"shutting-down",
148-
"stopping"
149-
).build()
150-
)
151-
}.instanceStatuses().firstOrNull()?.instanceState()?.name()
152-
} catch (ex: Ec2Exception) {
153-
LOGGER.debug(ex) { "Failed to get instance status." }
154-
null
155-
}
118+
val status = ec2Manager.getInstanceState(instanceIds[jobId]!!)
156119

157120
val heartbeat = s3Manager.getHeartbeat(newModelName, datasetName)
158121
val progress = s3Manager.getTrainingProgress(newModelName, datasetName)
@@ -167,15 +130,9 @@ class EC2TrainingScriptRunner(
167130

168131
override fun cancelScript(jobId: Int) {
169132
require(jobId in instanceIds.keys)
170-
171-
ec2.terminateInstances {
172-
it.instanceIds(instanceIds[jobId]!!)
173-
}
133+
ec2Manager.terminateInstance(instanceIds[jobId]!!)
174134
}
175135

176-
private fun String.toBase64() =
177-
Base64.getEncoder().encodeToString(byteInputStream().readAllBytes())
178-
179136
companion object {
180137
private val LOGGER = KotlinLogging.logger { }
181138

aws/src/main/kotlin/edu/wpi/axon/aws/S3Manager.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class S3Manager(
1515
private val bucketName: String
1616
) {
1717

18-
private val s3 by lazy { S3Client.builder().build() }
18+
private val s3 = S3Client.builder().build()
1919

2020
/**
2121
* Uploads an "untrained" model (one that the user wants to upload to start a job with). Meant

0 commit comments

Comments
 (0)