forked from Kitware/Danesfield
-
Notifications
You must be signed in to change notification settings - Fork 0
/
kwsemantic_segment.py
90 lines (72 loc) · 3.39 KB
/
kwsemantic_segment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python
###############################################################################
# Copyright Kitware Inc. and Contributors
# Distributed under the Apache License, 2.0 (apache.org/licenses/LICENSE-2.0)
# See accompanying Copyright.txt and LICENSE files for details
###############################################################################
import logging
import os
import sys
import numpy as np
from osgeo import gdal
import argparse
import json
from danesfield.segmentation.semantic.utils.utils import update_config
from danesfield.segmentation.semantic.tasks.seval import Evaluator
from danesfield.segmentation.semantic.utils.config import Config
# The pretrained model is unpickled with imports of a top-level "models"
# package (not "danesfield.segmentation.semantic.models"), so the directory
# that contains that package is appended to sys.path at import time.
_semantic_pkg_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "../danesfield/segmentation/semantic",
)
sys.path.append(_semantic_pkg_dir)
def _read_raster(path):
    """Return all bands of the raster at ``path`` as a NumPy array.

    Thin wrapper around ``gdal.Open(path).ReadAsArray()`` that fails with a
    clear message when the file cannot be opened.

    Raises:
        RuntimeError: if GDAL cannot open ``path``.
    """
    dataset = gdal.Open(path)
    if dataset is None:
        # Unless gdal.UseExceptions() is active, gdal.Open signals failure by
        # returning None; report the offending path instead of letting the
        # next line fail with "'NoneType' object has no attribute ...".
        raise RuntimeError("Unable to open raster: {}".format(path))
    return dataset.ReadAsArray()


def predict(rgbpath, dsmpath, dtmpath, msipath, outdir, outfname, config):
    """Build the 5-channel network input from the rasters and run prediction.

    Channels, in order: the RGB bands, the nDSM (DSM minus DTM, clipped to
    [0, 40] and rescaled to [0, 255]) and the NDVI computed from MSI bands
    4 and 6 (clipped to [0, 1] and rescaled to [0, 255]). The stacked array
    is moved to channels-first, mapped from [0, 255] to [-1, 1] as float32,
    and passed to ``Evaluator(config).onepredict``.

    Args:
        rgbpath: path to the RGB image (the CLI describes it as 3-band 8-bit).
        dsmpath: path to the DSM raster; also forwarded to ``onepredict``.
        dtmpath: path to the DTM raster.
        msipath: path to the multispectral image; needs at least 7 bands.
        outdir: output directory forwarded to ``onepredict``.
        outfname: output filename forwarded to ``onepredict`` (presumably
            used for the probability map / class mask it writes -- confirm).
        config: configuration object used to construct the ``Evaluator``.

    Raises:
        RuntimeError: if any input raster cannot be opened.
        ValueError: if the MSI image does not have at least 7 bands.
    """
    # (bands, rows, cols) -> (rows, cols, bands) so channels can be stacked.
    img_data = np.transpose(_read_raster(rgbpath), (1, 2, 0))

    # Height above the DTM, clipped to [0, 40] and rescaled to [0, 255]
    # so it shares the value range of the 8-bit RGB bands.
    ndsm_data = _read_raster(dsmpath) - _read_raster(dtmpath)
    ndsm_data = np.clip(ndsm_data, 0, 40) / 40 * 255

    msi_image = _read_raster(msipath)
    if msi_image.ndim != 3 or msi_image.shape[0] < 7:
        raise ValueError(
            "MSI image {} must have at least 7 bands, got array shape {}".format(
                msipath, msi_image.shape))
    # 0-based bands 4 and 6 are used as red and NIR. NOTE(review): this
    # presumably assumes WorldView-3 8-band ordering -- confirm for other
    # sensors. np.float64 replaces the np.float alias removed in NumPy 1.24.
    red_map = msi_image[4].astype(np.float64)
    nir_map = msi_image[6].astype(np.float64)
    # The epsilon keeps the ratio finite where both bands are zero.
    ndvi = (nir_map - red_map) / (nir_map + red_map + 1e-7)
    ndvi_data = np.clip(ndvi, 0, 1) * 255.0

    # (rows, cols, 5) -> (5, rows, cols), then [0, 255] -> [-1, 1].
    input_data = np.moveaxis(
        np.dstack([img_data, ndsm_data, ndvi_data]) / 255, -1, 0)
    input_data = (input_data.astype(np.float32) - 0.5) * 2

    keval = Evaluator(config)
    keval.onepredict(input_data, dsmpath, outdir, outfname)
def main(args):
    """Command-line entry point: parse arguments, configure and run prediction.

    Args:
        args: command-line argument strings without the program name
            (the ``__main__`` guard passes ``sys.argv[1:]``).
    """
    parser = argparse.ArgumentParser(
        description='configuration for semantic segmentation task.')
    # Every argument is positional and required, in this order.
    for arg_name, arg_help in (
            ('config_path', 'configuration file path.'),
            ('pretrain_model_path', 'pretrained model file path.'),
            ('rgbpath', '3-band 8-bit RGB image path'),
            ('dsmpath', '1-band float DSM file path'),
            ('dtmpath', '1-band float DTM file path'),
            ('msipath', '8-band float MSI file path'),
            ('outdir', 'directory in which to write output files'),
            ('outfname',
             'out filename for prediction probability and class mask'),
    ):
        parser.add_argument(arg_name, help=arg_help)
    opts = parser.parse_args(args)

    # Load the JSON configuration; the model path and output filename given
    # on the command line overwrite any same-named keys from the file.
    with open(opts.config_path, 'r') as config_file:
        cfg = json.load(config_file)
    cfg['pretrain_model_path'] = opts.pretrain_model_path
    cfg['out_fname'] = opts.outfname

    # Sizes are hard-coded to 2048; num_channels=5 matches the RGB + nDSM +
    # NDVI stack that predict() builds.
    config = update_config(Config(**cfg),
                           img_rows=2048, img_cols=2048,
                           target_rows=2048, target_cols=2048,
                           num_channels=5)
    predict(opts.rgbpath, opts.dsmpath, opts.dtmpath, opts.msipath,
            opts.outdir, opts.outfname, config)
if __name__ == "__main__":
    # Script boundary: any Exception is logged with its traceback and turned
    # into exit status 1. SystemExit (e.g. from argparse) is not an Exception
    # subclass, so it propagates unchanged.
    _cli_args = sys.argv[1:]
    try:
        main(_cli_args)
    except Exception as exc:
        logging.exception(exc)
        sys.exit(1)