-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
184 lines (167 loc) · 5.56 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
""" The main function that is the interface to the package cvmt """
import pytorch_lightning as pl
pl.seed_everything(100, workers=True)
import git
import argparse
import sys
import wandb
from cvmt.data import prep_all_datasets
from cvmt.ml import (
train_test_split,
trainer_edge_detection_single_task,
trainer_v_landmarks_single_task,
tester_v_landmarks_single_task,
)
from cvmt.verifier import verify_model_perf
from cvmt.utils import (
load_yaml_params,
nested_dict_to_easydict,
)
from cvmt.inference.inference import predict_image_cmd_interface
from easydict import EasyDict
# Pipeline step names accepted by `main`; `main` dispatches on them by position (STEPS[i]).
STEPS = [
    "data_prep",
    "train_test_split",
    "train",
    "verify",
    "test",
    "inference",
]
# Training tasks accepted by the "train" step.
TRAINING_TASKS = ["v_landmarks", "edges"]
# Data splits accepted as input by the "verify" step.
VERIFICATION_SPLIT = ["val", "test"]
# Path of the default parameters YAML file loaded when run as a script.
CONFIG_PARAMS_PATH = "configs/params.yaml"
def main(
    params: EasyDict,
    step: str,
    training_task: str = "v_landmarks",
    verify_split: str = "val",
    filepath: str = "",
    pix2cm: float = 10.0,
) -> None:
    """Main function to interact with the cvmt library.

    Args:
        params: An EasyDict of all the parameters needed to interact with the library.
            See `configs/params.yaml` for more info.
        step: Name of the pipeline step to run. Options are
            ["data_prep", "train_test_split", "train", "verify", "test", "inference"].
            When `None`, data preparation, train/test split and v_landmarks training
            are run one after another.
        training_task: Name of the training task used by the "train" step.
            Options are ["v_landmarks", "edges"].
        verify_split: Data split used as input by the "verify" step.
            Options are ["val", "test"].
        filepath: Path to the image used by the "inference" step (required there).
        pix2cm: Pixel-to-centimeter ratio, forwarded to inference as `px2cm_ratio`.

    Returns:
        None

    Note:
        An unknown `step`, `training_task` or `verify_split`, or a missing `filepath`
        for the "inference" step, prints a message and exits the process with code 1.
    """
    # setup wandb; a failure is reported but does not stop the selected step
    try:
        config_wandb(
            params,
        )
    except Exception as e:
        print(e)
        # The two literals are joined implicitly, so the first one needs a trailing space.
        print(
            "Make sure to export your private `wandb_api_key` into your terminal. "
            "Alternatively, follow the instructions in the README for the creation of a `.env` file in your configs directory."
        )
    # Execute the selected function
    if step == STEPS[0]:
        print(f"** Running {step}")
        prep_all_datasets(params)
    elif step == STEPS[1]:
        print(f"** Running {step}")
        train_test_split(params)
    elif step == STEPS[2]:
        print(f"** Running {step}")
        if training_task == TRAINING_TASKS[0]:
            print(f"** Running training for {training_task}")
            trainer_v_landmarks_single_task(
                params,
            )
        elif training_task == TRAINING_TASKS[1]:
            print(f"** Running training for {training_task}")
            trainer_edge_detection_single_task(
                params,
            )
        else:
            print(f"Unknown training_task is supplied to the command: {training_task}")
            sys.exit(1)
    elif step == STEPS[3]:
        print(f"** Running {step}")
        if verify_split in VERIFICATION_SPLIT:
            print(f"** Running verification for {verify_split}")
            verify_model_perf(params, split=verify_split)
        else:
            print(f"Unknown verify_split is supplied to the command: {verify_split}")
            sys.exit(1)
    elif step == STEPS[4]:
        print(f"** Running {step}")
        tester_v_landmarks_single_task(
            params,
        )
    elif step == STEPS[5]:
        print(f"** Running {step}")
        # Inference needs an input image; fail early with a clear message instead of
        # handing an empty/None path to the predictor.
        if not filepath:
            print("No filepath is supplied to the command for the inference step")
            sys.exit(1)
        stage = predict_image_cmd_interface(
            params, filepath=filepath, px2cm_ratio=pix2cm
        )
        print(f"******* bone age maturity stage for the image is {stage} ********")
    elif (step not in STEPS) and (step is not None):
        print(f"Unknown step is supplied to the command: {step}")
        sys.exit(1)
    else:
        # step is None: run the default end-to-end pipeline
        print("****** Running prep_all_datasets ****** ")
        prep_all_datasets(params)
        print("****** Running train_test_split ****** ")
        train_test_split(params)
        print(
            "****** Running trainer_v_landmarks_single_task without pretraining! ****** "
        )
        trainer_v_landmarks_single_task(params)
    return None
def parse_arguments():
    """Parse the command-line arguments of the cvmt entry point.

    The defaults of `--filepath` ("") and `--pix2cm` (10.0) match the defaults of
    `main`'s `filepath` and `pix2cm` parameters. Without them argparse would yield
    `None`, which overrides `main`'s defaults when the namespace is splatted into it.
    `--step` has no default; `None` means "no step selected".

    Returns:
        argparse.Namespace with attributes `step`, `training_task`, `verify_split`,
        `filepath` and `pix2cm`.
    """
    parser = argparse.ArgumentParser(description="Read command line arguments.")
    parser.add_argument(
        "--step",
        type=str,
        help="pipeline step",
    )
    parser.add_argument(
        "--training-task", type=str, help="training_task", default="v_landmarks"
    )
    parser.add_argument(
        "--verify-split", type=str, help="verification input data split", default="val"
    )
    parser.add_argument(
        "--filepath",
        type=str,
        help="path to the image for inference",
        default="",
    )
    parser.add_argument(
        "--pix2cm",
        type=float,
        help="pixel to centimeter ratio as depicted on the image ruler",
        default=10.0,
    )
    args = parser.parse_args()
    return args
def get_git_info():
    """Return the current commit hash and branch name of the enclosing git repository.

    The repository is located with `search_parent_directories=True`, so it may live
    in a parent directory.

    Returns:
        A tuple ``(commit_hash, branch)``. ``commit_hash`` is the hex SHA of HEAD.
        ``branch`` is the active branch name, or ``None`` when HEAD is detached
        (e.g. a CI checkout of a specific commit).
    """
    repo = git.Repo(search_parent_directories=True)
    commit_hash = repo.head.object.hexsha
    # `active_branch` raises TypeError on a detached HEAD, so check before reading it.
    branch = None if repo.head.is_detached else repo.active_branch.name
    return commit_hash, branch
def config_wandb(
    params,
):
    """Start a fresh Weights & Biases run configured from `params`.

    Steps:
        1. Finish any previous run and log in.
        2. Initialize a run with `params.WANDB.INIT` as keyword arguments and
           `params` as the run config.
        3. Upload the code in the current directory.
        4. Record the git commit hash and branch in the run summary, when they
           can be read.

    Args:
        params: EasyDict of the library parameters; must provide `WANDB.INIT`.

    Raises:
        Exceptions from `wandb.login` / `wandb.init` / `log_code` are propagated.
        Failures to read git metadata are reported and skipped.
    """
    # make sure not to re-use an old run
    wandb.finish()
    # login to the wandb server and initialize
    wandb.login()
    # NOTE: dict() is a shallow conversion; nested values are passed through unchanged.
    config = dict(params)
    run = wandb.init(
        **params.WANDB.INIT,
        config=config,
    )
    # log code
    run.log_code(".")
    # Git metadata is optional. A missing or empty repository should not abort the
    # setup after the run has already started.
    try:
        commit_hash, branch = get_git_info()
    except (git.exc.GitError, TypeError, ValueError) as err:
        print(f"Could not read git metadata for the wandb run: {err}")
        return
    # log code commit hash and branch
    run.summary["git_commit_hash"] = commit_hash
    run.summary["git_branch"] = branch
if __name__ == "__main__":
    # Default parameters come from the YAML config file; the CLI selects what to run.
    cli_params: EasyDict = nested_dict_to_easydict(load_yaml_params(CONFIG_PARAMS_PATH))
    cli_kwargs = vars(parse_arguments())
    main(params=cli_params, **cli_kwargs)