Skip to content

Commit

Permalink
added sdk branch to requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxTeselkin committed Jul 4, 2023
1 parent b87b4c8 commit 67b85a5
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 16 deletions.
2 changes: 1 addition & 1 deletion train/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
supervisely==6.72.56
git+https://github.com/supervisely/supervisely.git@grid-gallery-improvements
ultralytics==8.0.112
--extra-index-url https://download.pytorch.org/whl/cu113
torch==1.10.1+cu113
Expand Down
65 changes: 50 additions & 15 deletions train/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ def update_globals(new_dataset_ids):
additional_gallery = GridGallery(
columns_number=3,
show_opacity_slider=False,
enable_zoom=True,
)
additional_gallery_f = Field(additional_gallery, "Additional training results visualization")
additional_gallery_f.hide()
Expand Down Expand Up @@ -721,7 +722,9 @@ def change_file_preview(value):
@additional_config_radio.value_changed
def change_radio(value):
if value == "import template from Team Files":
remote_templates_dir = os.path.join("/yolov8_train", task_type_select.get_value(), "param_templates")
remote_templates_dir = os.path.join(
"/yolov8_train", task_type_select.get_value(), "param_templates"
)
templates = api.file.list(team_id, remote_templates_dir)
if len(templates) == 0:
no_templates_notification.show()
Expand All @@ -736,7 +739,9 @@ def change_radio(value):

@additional_config_template_select.value_changed
def change_template(template):
remote_templates_dir = os.path.join("/yolov8_train", task_type_select.get_value(), "param_templates")
remote_templates_dir = os.path.join(
"/yolov8_train", task_type_select.get_value(), "param_templates"
)
remote_template_path = os.path.join(remote_templates_dir, template)
local_template_path = os.path.join(g.app_data_dir, template)
api.file.download(team_id, remote_template_path, local_template_path)
Expand All @@ -748,7 +753,9 @@ def change_template(template):
@save_template_button.click
def upload_template():
save_template_button.loading = True
remote_templates_dir = os.path.join("/yolov8_train", task_type_select.get_value(), "param_templates")
remote_templates_dir = os.path.join(
"/yolov8_train", task_type_select.get_value(), "param_templates"
)
additional_params = train_settings_editor.get_text()
ryaml = ruamel.yaml.YAML()
additional_params = ryaml.load(additional_params)
Expand Down Expand Up @@ -863,10 +870,15 @@ def start_training():
if task_type != "object detection":
unnecessary_classes = []
for cls in project_meta.obj_classes:
if cls.name in selected_classes and cls.geometry_type.geometry_name() not in necessary_geometries:
if (
cls.name in selected_classes
and cls.geometry_type.geometry_name() not in necessary_geometries
):
unnecessary_classes.append(cls.name)
if len(unnecessary_classes) > 0:
sly.Project.remove_classes(g.project_dir, classes_to_remove=unnecessary_classes, inplace=True)
sly.Project.remove_classes(
g.project_dir, classes_to_remove=unnecessary_classes, inplace=True
)
# transfer project to detection task if necessary
if task_type == "object detection":
sly.Project.to_detection_task(g.project_dir, inplace=True)
Expand All @@ -888,7 +900,9 @@ def start_training():
description="Val split length is 0 after ignoring images. Please check your data",
status="error",
)
raise ValueError("Val split length is 0 after ignoring images. Please check your data")
raise ValueError(
"Val split length is 0 after ignoring images. Please check your data"
)
# split the data
train_set, val_set = get_train_val_sets(g.project_dir, train_val_split, api, project_id)
verify_train_val_sets(train_set, val_set)
Expand Down Expand Up @@ -922,7 +936,9 @@ def download_monitor(monitor, api: sly.Api, progress: sly.Progress):
model_filename = selected_model.lower() + ".pt"
pretrained = True
weights_dst_path = os.path.join(g.app_data_dir, model_filename)
weights_url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_filename}"
weights_url = (
f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_filename}"
)
with urlopen(weights_url) as file:
weights_size = file.length

Expand Down Expand Up @@ -1069,10 +1085,16 @@ def on_results_file_changed(filepath, pbar):
# visualize train batch
batch = f"train_batch{x}.jpg"
local_train_batches_path = os.path.join(local_artifacts_dir, batch)
if os.path.exists(local_train_batches_path) and batch not in plotted_train_batches and x < 10:
if (
os.path.exists(local_train_batches_path)
and batch not in plotted_train_batches
and x < 10
):
plotted_train_batches.append(batch)
remote_train_batches_path = os.path.join(remote_images_path, batch)
tf_train_batches_info = api.file.upload(team_id, local_train_batches_path, remote_train_batches_path)
tf_train_batches_info = api.file.upload(
team_id, local_train_batches_path, remote_train_batches_path
)
train_batches_gallery.append(tf_train_batches_info.full_storage_url)
if x == 0:
train_batches_gallery_f.show()
Expand Down Expand Up @@ -1158,8 +1180,12 @@ def train_batch_watcher_func():
# visualize additional training results
confusion_matrix_path = os.path.join(local_artifacts_dir, "confusion_matrix_normalized.png")
if os.path.exists(confusion_matrix_path):
remote_confusion_matrix_path = os.path.join(remote_images_path, "confusion_matrix_normalized.png")
tf_confusion_matrix_info = api.file.upload(team_id, confusion_matrix_path, remote_confusion_matrix_path)
remote_confusion_matrix_path = os.path.join(
remote_images_path, "confusion_matrix_normalized.png"
)
tf_confusion_matrix_info = api.file.upload(
team_id, confusion_matrix_path, remote_confusion_matrix_path
)
additional_gallery.append(tf_confusion_matrix_info.full_storage_url)
additional_gallery_f.show()
pr_curve_path = os.path.join(local_artifacts_dir, "PR_curve.png")
Expand All @@ -1180,18 +1206,24 @@ def train_batch_watcher_func():
pose_f1_curve_path = os.path.join(local_artifacts_dir, "PoseF1_curve.png")
if os.path.exists(pose_f1_curve_path):
remote_pose_f1_curve_path = os.path.join(remote_images_path, "PoseF1_curve.png")
tf_pose_f1_curve_info = api.file.upload(team_id, pose_f1_curve_path, remote_pose_f1_curve_path)
tf_pose_f1_curve_info = api.file.upload(
team_id, pose_f1_curve_path, remote_pose_f1_curve_path
)
additional_gallery.append(tf_pose_f1_curve_info.full_storage_url)
mask_f1_curve_path = os.path.join(local_artifacts_dir, "MaskF1_curve.png")
if os.path.exists(mask_f1_curve_path):
remote_mask_f1_curve_path = os.path.join(remote_images_path, "MaskF1_curve.png")
tf_mask_f1_curve_info = api.file.upload(team_id, mask_f1_curve_path, remote_mask_f1_curve_path)
tf_mask_f1_curve_info = api.file.upload(
team_id, mask_f1_curve_path, remote_mask_f1_curve_path
)
additional_gallery.append(tf_mask_f1_curve_info.full_storage_url)

# rename best checkpoint file
results = pd.read_csv(watch_file)
results.columns = [col.replace(" ", "") for col in results.columns]
results["fitness"] = (0.1 * results["metrics/mAP50(B)"]) + (0.9 * results["metrics/mAP50-95(B)"])
results["fitness"] = (0.1 * results["metrics/mAP50(B)"]) + (
0.9 * results["metrics/mAP50-95(B)"]
)
print("Final results:")
print(results)
best_epoch = results["fitness"].idxmax()
Expand All @@ -1203,7 +1235,10 @@ def train_batch_watcher_func():
# add geometry config to saved weights for pose estimation task
if task_type == "pose estimation":
for obj_class in project_meta.obj_classes:
if obj_class.geometry_type.geometry_name() == "graph" and obj_class.name in selected_classes:
if (
obj_class.geometry_type.geometry_name() == "graph"
and obj_class.name in selected_classes
):
geometry_config = obj_class.geometry_config
break
weights_filepath = os.path.join(local_artifacts_dir, "weights", best_filename)
Expand Down

0 comments on commit 67b85a5

Please sign in to comment.