forked from BhaskarJoshi-01/DroneControl
-
Notifications
You must be signed in to change notification settings - Fork 1
/
TrainDispatcher.py
52 lines (40 loc) · 1.89 KB
/
TrainDispatcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Train dispatcher script for easier training on ADA
import argparse
import json
import os
import subprocess
import sys
import tempfile
parser = argparse.ArgumentParser()
parser.add_argument("trainConfigPath", help="Path to train config file")
parser.add_argument("-s", "--steps", default=2_000_000, help="Number of timesteps to train for", type=int)
parser.add_argument("-o", "--obstacles", default=None, help="Number of obstacles", type=int)
parser.add_argument('--local', action='store_true', help='Run on Local Machine')
parser.add_argument("-d", "--dynamic", default=False, help="Use Dynamic Obstacles", type=bool)
args = parser.parse_args()
if args.local:
with open(args.trainConfigPath, 'r') as f:
trainConfig = json.load(f)
taskName = trainConfig["taskName"]
envConfig = trainConfig["envConfigFile"]
modelName = trainConfig["outputModelName"]
os.chdir('SBAgent')
os.system(f"python TrainModel.py {envConfig} {modelName} --steps {args.steps} --dynamic {args.dynamic} -o {args.obstacles}")
else:
with open('trainScriptTemplate.sh', 'r') as f:
script = ''.join(f.readlines())
with open(args.trainConfigPath, 'r') as f:
trainConfig = json.load(f)
taskName = trainConfig["taskName"]
envConfig = trainConfig["envConfigFile"]
modelName = trainConfig["outputModelName"]
script = script.replace("{outputFile}", f"jobOutputs/{taskName}_output.txt")
script = script.replace("{jobName}", f"{taskName}")
script = script.replace("{configFile}", envConfig)
script = script.replace("{outputModelName}", modelName)
script = script.replace("{steps}", str(args.steps))
script = script.replace("{obstacles}", str(args.obstacles))
script = script.replace("{dynamic}", str(args.dynamic))
tmp = tempfile.NamedTemporaryFile()
with open(tmp.name, 'w') as f:
f.write(script)
print(f"Dispatching Train Job for {taskName}")
os.system(f"sbatch {tmp.name}")