Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Like validate, for training use determine classes, train_label_infos from config_helper #171

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Add support to train subject on different pixel sizes (#143)
- Add support to overrule configuration parameters via command line arguments (#152)
- Several small improvements (#128)
- Add command to (only) validate a training dataset (#133)

### Bugs fixed

Expand Down
1 change: 1 addition & 0 deletions orthoseg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from orthoseg.train import train
from orthoseg.predict import predict
from orthoseg.postprocess import postprocess
from orthoseg.validate import validate


def _get_version():
Expand Down
72 changes: 70 additions & 2 deletions orthoseg/helpers/config_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
import pprint
import re
import tempfile
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

from orthoseg.util import config_util
from orthoseg.util.ows_util import FileLayerSource, WMSLayerSource
from orthoseg.lib.prepare_traindatasets import LabelInfo
from orthoseg.lib import prepare_traindatasets as prep


# Get a logger...
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -168,19 +170,85 @@
return tmp_dir


def prepare_traindatasets() -> Tuple[Path, int]:
    """
    Prepare the train, validation and test datasets for the loaded config.

    If train.force_model_traindata_id is larger than -1, nothing is prepared and
    the training directory belonging to that id is returned.

    Returns:
        Tuple[Path, int]: training directory and traindata id
    """
    # Make sure the project and training directories exist
    for output_dir in (dirs.getpath("project_dir"), dirs.getpath("training_dir")):
        if not output_dir or output_dir.exists():
            continue
        output_dir.mkdir()

    # A forced traindata id short-circuits the preparation of the datasets
    forced_traindata_id = train.getint("force_model_traindata_id")
    if forced_traindata_id > -1:
        forced_dir = dirs.getpath("training_dir") / f"{forced_traindata_id:02d}"
        return (forced_dir, forced_traindata_id)

    # Otherwise, create the train datasets (train, validation, test)
    logger.info("Prepare train, validation and test data")
    label_infos = get_train_label_infos()
    classes = determine_classes()
    training_dir, traindata_id = prep.prepare_traindatasets(
        label_infos=label_infos,
        classes=classes,
        image_layers=image_layers,
        training_dir=dirs.getpath("training_dir"),
        labelname_column=train.get("labelname_column"),
        image_pixel_x_size=train.getfloat("image_pixel_x_size"),
        image_pixel_y_size=train.getfloat("image_pixel_y_size"),
        image_pixel_width=train.getint("image_pixel_width"),
        image_pixel_height=train.getint("image_pixel_height"),
        ssl_verify=general["ssl_verify"],
    )
    return (training_dir, traindata_id)


def get_train_label_infos() -> List[LabelInfo]:
    """
    Look up the LabelInfos that can be used to create a training dataset.

    Raises:
        ValueError: if no LabelInfos were found.

    Returns:
        List[LabelInfo]: List of LabelInfos found.
    """
    label_infos = _prepare_train_label_infos(
        labelpolygons_pattern=train.getpath("labelpolygons_pattern"),
        labellocations_pattern=train.getpath("labellocations_pattern"),
        label_datasources=train.getdict("label_datasources", None),
        image_layers=image_layers,
    )
    if label_infos is not None and len(label_infos) > 0:
        return label_infos

    # Nothing usable was found: report which config was searched
    raise ValueError(
        "No valid label file config found in train.label_datasources or "
        f"with patterns {train.get('labelpolygons_pattern')} and "
        f"{train.get('labellocations_pattern')}"
    )


def determine_classes():
    """
    Read the classes from the train config and complete their burn values.

    Classes that have no "burn_value" get their position in the classes
    configuration as burn value.

    Raises:
        Exception: Error reading classes

    Returns:
        any: classes
    """
    try:
        classes = train.getdict("classes")

        # Fill in the burn_value for every class that doesn't specify one
        for position, class_info in enumerate(classes.values()):
            if "burn_value" in class_info:
                continue
            class_info["burn_value"] = position
    except Exception as ex:
        raise Exception(f"Error reading classes: {train.get('classes')}") from ex

    return classes

Check warning on line 251 in orthoseg/helpers/config_helper.py

View check run for this annotation

Codecov / codecov/patch

orthoseg/helpers/config_helper.py#L250-L251

Added lines #L250 - L251 were not covered by tests


def _read_layer_config(layer_config_filepath: Path) -> dict:
Expand Down
63 changes: 3 additions & 60 deletions orthoseg/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,71 +97,14 @@ def train(config_path: Path, config_overrules: List[str] = []):
logger.debug(f"Config used: \n{conf.pformat_config()}")

try:
# First check if the segment_subject has a valid name
segment_subject = conf.general["segment_subject"]
if segment_subject == "MUST_OVERRIDE":
raise Exception(
"segment_subject must be overridden in the subject specific config file"
)
elif "_" in segment_subject:
raise Exception(f"segment_subject cannot contain '_': {segment_subject}")

# Create the output dir's if they don't exist yet...
for dir in [
conf.dirs.getpath("project_dir"),
conf.dirs.getpath("training_dir"),
]:
if dir and not dir.exists():
dir.mkdir()

# If the training data doesn't exist yet, create it
# -------------------------------------------------
train_label_infos = conf.get_train_label_infos()
if train_label_infos is None or len(train_label_infos) == 0:
raise ValueError(
"No valid label file config found in train.label_datasources or "
f"with patterns {conf.train.get('labelpolygons_pattern')} and "
f"{conf.train.get('labellocations_pattern')}"
)

# Determine the projection of (the first) train layer... it will be used for all
train_image_layer = train_label_infos[0].image_layer
train_image_layer = conf.get_train_label_infos()[0].image_layer
train_projection = conf.image_layers[train_image_layer]["projection"]

# Determine classes
try:
classes = conf.train.getdict("classes")

# If the burn_value property isn't supplied for the classes, add them
for class_id, (classname) in enumerate(classes):
if "burn_value" not in classes[classname]:
classes[classname]["burn_value"] = class_id
except Exception as ex:
raise Exception(
f"Error reading classes: {conf.train.get('classes')}"
) from ex
classes = conf.determine_classes()

# Now create the train datasets (train, validation, test)
force_model_traindata_id = conf.train.getint("force_model_traindata_id")
if force_model_traindata_id > -1:
training_dir = (
conf.dirs.getpath("training_dir") / f"{force_model_traindata_id:02d}"
)
traindata_id = force_model_traindata_id
else:
logger.info("Prepare train, validation and test data")
training_dir, traindata_id = prep.prepare_traindatasets(
label_infos=train_label_infos,
classes=classes,
image_layers=conf.image_layers,
training_dir=conf.dirs.getpath("training_dir"),
labelname_column=conf.train.get("labelname_column"),
image_pixel_x_size=conf.train.getfloat("image_pixel_x_size"),
image_pixel_y_size=conf.train.getfloat("image_pixel_y_size"),
image_pixel_width=conf.train.getint("image_pixel_width"),
image_pixel_height=conf.train.getint("image_pixel_height"),
ssl_verify=conf.general["ssl_verify"],
)
training_dir, traindata_id = conf.prepare_traindatasets()

# Send mail that we are starting train
email_helper.sendmail(f"Start train for config {config_path.stem}")
Expand Down
122 changes: 122 additions & 0 deletions orthoseg/validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Module to make it easy to start a validating session.
"""

import argparse
import logging
from pathlib import Path
import shutil
import sys
import traceback
from typing import List


from orthoseg.helpers import config_helper as conf
from orthoseg.helpers import email_helper
from orthoseg.lib import prepare_traindatasets as prep
from orthoseg.util import log_util

# Get a logger...
logger = logging.getLogger(__name__)


def _validate_args(args) -> argparse.Namespace:
# Interprete arguments
parser = argparse.ArgumentParser(add_help=False)

# Required arguments
required = parser.add_argument_group("Required arguments")
required.add_argument(
"-c", "--config", type=str, required=True, help="The config file to use"
)

# Optional arguments
optional = parser.add_argument_group("Optional arguments")
# Add back help
optional.add_argument(
"-h",
"--help",
action="help",
default=argparse.SUPPRESS,
help="Show this help message and exit",
)
optional.add_argument(
"config_overrules",
nargs="*",
help=(
"Supply any number of config overrules like this: "
"<section>.<parameter>=<value>"
),
)

return parser.parse_args(args)


def validate(config_path: Path, config_overrules: List[str] = []):
    """
    Run a validating session for the config specified.

    The training datasets are prepared via the config helper. If this fails, the
    error is logged and mailed, and an Exception is raised.

    Args:
        config_path (Path): Path to the config file to use.
        config_overrules (List[str], optional): list of config options that will
            overrule other ways to supply configuration. They should be specified in the
            form of "<section>.<parameter>=<value>". Defaults to [].

    Raises:
        Exception: if anything goes wrong while validating; the original
            exception is chained as the cause.
    """
    # Load the config: it is kept in global variables of the config helper so it
    # is accessible everywhere
    conf.read_orthoseg_config(config_path, overrules=config_overrules)

    # Clean up old log files, then init the logging for this run
    log_util.clean_log_dir(
        log_dir=conf.dirs.getpath("log_dir"),
        nb_logfiles_tokeep=conf.logging.getint("nb_logfiles_tokeep"),
    )
    global logger
    logger = log_util.main_log_init(conf.dirs.getpath("log_dir"), __name__)

    logger.info(f"Start validate for config {config_path.stem}")
    logger.debug(f"Config used: \n{conf.pformat_config()}")

    try:
        # Create the train datasets (train, validation, test)
        training_dir, traindata_id = conf.prepare_traindatasets()

        # Send the mail; note that this happens once the datasets are prepared
        email_helper.sendmail(f"Start validate for config {config_path.stem}")
        logger.info(
            f"Traindata dir to use is {training_dir}, with traindata_id: {traindata_id}"
        )
    except Exception as ex:
        error_subject = f"ERROR while running validate for task {config_path.stem}"
        logger.exception(error_subject)
        # Validation errors supply their own html, other errors get a traceback
        error_body = (
            f"Validation error: {ex.to_html()}"
            if isinstance(ex, prep.ValidationError)
            else f"Exception: {ex}<br/><br/>{traceback.format_exc()}"
        )
        email_helper.sendmail(subject=error_subject, body=error_body)
        raise Exception(error_subject) from ex
    finally:
        # Always remove the temporary directory, if there is one
        if conf.tmp_dir is not None:
            shutil.rmtree(conf.tmp_dir, ignore_errors=True)


def main():
    """
    Parse the command line arguments and run validate with them.

    Any exception is logged and raised again.
    """
    try:
        # Interpret the arguments (without the program name)
        cli_args = _validate_args(sys.argv[1:])

        # Run!
        validate(
            config_path=Path(cli_args.config),
            config_overrules=cli_args.config_overrules,
        )
    except Exception as ex:
        logger.exception(f"Error: {ex}")
        raise

Check warning on line 117 in orthoseg/validate.py

View check run for this annotation

Codecov / codecov/patch

orthoseg/validate.py#L114-L117

Added lines #L114 - L117 were not covered by tests


# If the script is run directly...
if __name__ == "__main__":
main()

Check warning on line 122 in orthoseg/validate.py

View check run for this annotation

Codecov / codecov/patch

orthoseg/validate.py#L122

Added line #L122 was not covered by tests
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
entry_points="""
[console_scripts]
orthoseg_load_images=orthoseg.load_images:main
orthoseg_validate=orthoseg.validate:main
orthoseg_train=orthoseg.train:main
orthoseg_predict=orthoseg.predict:main
orthoseg_postprocess=orthoseg.postprocess:main
Expand Down
42 changes: 40 additions & 2 deletions tests/test_end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Tests for functionalities in orthoseg.train.
"""

from contextlib import nullcontext
from datetime import datetime
import os
from pathlib import Path
Expand Down Expand Up @@ -54,12 +55,49 @@ def test_2_load_images():
assert len(files) == 6


@pytest.mark.parametrize("exp_error", [(False)])
@pytest.mark.skipif(
    "GITHUB_ACTIONS" in os.environ and os.name == "nt",
    reason="crashes on github CI on windows",
)
@pytest.mark.order(after="test_1_init_testproject")
def test_3_validate(exp_error: bool):
    """Run orthoseg.validate on the footballfields train test project."""
    # Load project config to init some vars.
    config_path = footballfields_dir / "footballfields_train_test.ini"
    conf.read_orthoseg_config(config_path)

    # Remove the result dir of a previous run
    traindata_id_result = 2
    training_dir = conf.dirs.getpath("training_dir")
    training_id_dir = training_dir / f"{traindata_id_result:02d}"
    if training_id_dir.exists():
        shutil.rmtree(training_id_dir)

    # Remove the model files of a previous run
    model_dir = conf.dirs.getpath("model_dir")
    if model_dir.exists():
        for modelfile_path in model_dir.glob(
            f"footballfields_{traindata_id_result:02d}_*"
        ):
            modelfile_path.unlink()

    # Make sure the label files in version 01 are older than those in the label dir
    label_01_path = training_dir / "01/footballfields_BEFL-2019_polygons.gpkg"
    timestamp_old = datetime(year=2020, month=1, day=1).timestamp()
    os.utime(label_01_path, (timestamp_old, timestamp_old))

    # Only expect an exception if the test case asks for it
    handler = pytest.raises(Exception) if exp_error else nullcontext()
    with handler:
        orthoseg.validate(config_path=config_path)


@pytest.mark.skipif(
"GITHUB_ACTIONS" in os.environ and os.name == "nt",
reason="crashes on github CI on windows",
)
@pytest.mark.order(after="test_1_init_testproject")
def test_4_train():
# Load project config to init some vars.
config_path = footballfields_dir / "footballfields_train_test.ini"
conf.read_orthoseg_config(config_path)
Expand Down Expand Up @@ -105,7 +143,7 @@ def test_3_train():
reason="crashes on github CI on windows",
)
@pytest.mark.order(after="test_2_load_images")
def test_4_predict():
def test_5_predict():
# Load project config to init some vars.
config_path = footballfields_dir / "footballfields_BEFL-2019_test.ini"
conf.read_orthoseg_config(config_path)
Expand Down
Loading