Skip to content

Commit

Permalink
config file includes all parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
tjvandal committed Oct 2, 2017
1 parent 76f34e5 commit 8566f06
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 49 deletions.
12 changes: 10 additions & 2 deletions config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@ save_step: 1000
test_step: 50
dropout_prob: 0.0

[Experiment-1]
[Model-1]
data_dir: /gss_gpfs_scratch/vandal.t/deepsd-test/ppt_008_016/
model_name: ppt-008-016

[Experiment-2]
[Model-2]
data_dir: /gss_gpfs_scratch/vandal.t/deepsd-test/ppt_004_008/
model_name: ppt-004-008

[DeepSD]
model_name: ppt-004-016
low_resolution: 16
high_resolution: 4
upscale_factor: 2
80 changes: 48 additions & 32 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,33 @@
import prism
import utils

# model parameters
flags = tf.flags
flags.DEFINE_string('config_file', 'config.ini',
'Configuration file with [SRCNN], [Model-%], and [DeepSD] sections.')

# parse flags
FLAGS = flags.FLAGS
FLAGS._parse_flags()

config = ConfigParser.ConfigParser()
config.read(FLAGS.config_file)

PRISM_DIR = os.path.join(config.get('Paths', 'prism'), 'ppt', 'raw')

model_sections = [(s,int(s.split('-')[1])) for s in config.sections() if 'Model' in s]
model_sections.sort(key=lambda tup: tup[1])


LAYER_SIZES = [int(k) for k in config.get('SRCNN', 'layer_sizes').split(",")]
KERNEL_SIZES = [int(k) for k in config.get('SRCNN', 'kernel_sizes').split(",")]
UPSCALE_FACTOR = config.getint('DeepSD', 'upscale_factor')

CHECKPOINT_DIR = os.path.join(config.get('SRCNN', 'scratch'), "srcnn_%s_%s_%s" % ( '%s',
'-'.join([str(s) for s in LAYER_SIZES]),
'-'.join([str(s) for s in KERNEL_SIZES])))
CHECKPOINTS = [CHECKPOINT_DIR % config.get(m[0], 'model_name') for m in model_sections ]
DEEPSD_MODEL_NAME = config.get('DeepSD', 'model_name')

def get_graph_def():
with tf.Session() as sess:
checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
Expand Down Expand Up @@ -50,7 +71,7 @@ def freeze_graph(model_folder, graph_name=None):
raise ValueError("Give me a graph_name")

# We clear devices to allow TensorFlow to control on which device it will load operations
clear_devices = True
clear_devices = True

# We import the meta graph and retrieve a Saver
saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
Expand Down Expand Up @@ -142,22 +163,17 @@ def join_graphs(checkpoints, new_checkpoint):
# resize low-resolution
h = tf.shape(x)[1]
w = tf.shape(x)[2]
size = tf.stack([h*2, w*2])
size = tf.stack([h*UPSCALE_FACTOR, w*UPSCALE_FACTOR])
x = tf.image.resize_bilinear(x, size)

# join elevation and interpolated image
x = tf.concat([x, elv], axis=3)
graph_name = "_".join(os.path.basename(cpt.strip("/")).split("_")[1:4])

# load frozen graph with x as the input
print 'x', x
next_input = graph_name + '/x'
x = load_graph(os.path.join(cpt, 'frozen_model.pb'), graph_name, x=x)

for var in tf.global_variables():
print 'chkp', j, var.op.name
time.sleep(0.1)

with tf.Session() as sess:
summary_op = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(new_checkpoint, sess.graph)
Expand All @@ -170,37 +186,26 @@ def join_graphs(checkpoints, new_checkpoint):
print("%d ops in the final graph." % len(gd.node))

tf.reset_default_graph()
return output_graph

def main(frozen_graph, scale1=1., scale2=1./2, n_stacked=1, upscale_factor=2):
# read configuration file
config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.ini')
config = ConfigParser.ConfigParser()
config.read(config_file)
return output_graph, x.name

year = 2015
def main(frozen_graph, output_node, year, scale1=1., n_stacked=1):
# read prism dataset
## resnet parameter will not re-interpolate X
dataset = prism.PrismSuperRes(os.path.join(config.get('Paths', 'prism'), 'ppt','raw'), year,
config.get('Paths', 'elevation'), model='srcnn')
dataset = prism.PrismSuperRes(PRISM_DIR, year, config.get('Paths', 'elevation'), model='srcnn')

X, elev, Y, lats, lons, times = dataset.make_test(scale1=scale1, scale2=scale2)
print X.shape, elev.shape, Y.shape
X, elev, Y, lats, lons, times = dataset.make_test(scale1=scale1, scale2=1./UPSCALE_FACTOR**2)
mask = (Y[0,:,:,0]+1)/(Y[0,:,:,0] + 1)
elev_hr = elev[0,:,:,0] # all the elevations are the same, remove some data from memory

# resize x
n, h, w, c = X.shape
print 'Y Shape', Y.shape
print 'X shape', X.shape

# get elevations at all 5 resolutions
elev_dict = {}
elevs = []
for i in range(n_stacked):
r = upscale_factor**i
r = UPSCALE_FACTOR**i
elev_dict[1./r] = cv2.resize(elev_hr, (0,0), fx=1./r, fy=1./r)
print 'elev shape', elev_dict[1./r].shape
elevs.append(tf.constant(elev_dict[1./r][np.newaxis, :, :, np.newaxis].astype(np.float32)))
elevs = elevs[::-1]

Expand All @@ -218,7 +223,7 @@ def main(frozen_graph, scale1=1., scale2=1./2, n_stacked=1, upscale_factor=2):
y, = tf.import_graph_def(
graph_def,
input_map=input_map,
return_elements=['ppt_008_016/prediction:0'],
return_elements=[output_node],
name='deepsd',
op_dict=None,
producer_op_list=None
Expand All @@ -241,6 +246,7 @@ def main(frozen_graph, scale1=1., scale2=1./2, n_stacked=1, upscale_factor=2):
xr.Dataset({'precip': precip}).to_netcdf("precip_%i_downscaled.nc" % year)

fig, axs = plt.subplots(3,1)
ymax = np.nanmax(Y)
axs = np.ravel(axs)
axs[0].imshow(Y[0,:,:,0], vmax=ymax)
axs[0].axis('off')
Expand All @@ -255,14 +261,24 @@ def main(frozen_graph, scale1=1., scale2=1./2, n_stacked=1, upscale_factor=2):
plt.close()

if __name__ == '__main__':
checkpoints = [
'scratch/srcnn_ppt_008_016_64-32-1_9-1-5/',
]
checkpoints = sorted(checkpoints)[::-1]
joined_checkpoint = os.path.join(os.path.dirname(checkpoints[0][:-1]), 'joined_008_016')
highest_resolution = 4
hr_resolution_km = config.getint('DeepSD', 'high_resolution')
lr_resolution_km = config.getint('DeepSD', 'low_resolution')
start = hr_resolution_km / highest_resolution
N = int((lr_resolution_km / hr_resolution_km)**(1./UPSCALE_FACTOR))

CHECKPOINTS = sorted(CHECKPOINTS)[::-1]
if len(CHECKPOINTS) != int(N):
raise ValueError

joined_checkpoint = os.path.join(os.path.dirname(CHECKPOINTS[0][:-1]), DEEPSD_MODEL_NAME)

if not os.path.exists(joined_checkpoint):
os.mkdir(joined_checkpoint)

#new_graph = join_graphs(checkpoints, joined_checkpoint)
new_graph, output_node = join_graphs(CHECKPOINTS, joined_checkpoint)
new_graph = os.path.join(joined_checkpoint, 'frozen_graph.pb')
main(new_graph, scale1=1./2, scale2=1./2)
year1 = config.getint('DataOptions', 'max_train_year')+1
yearlast = config.getint('DataOptions', 'max_year')
for y in range(year1, yearlast+1):
main(new_graph, output_node, y, scale1=start, n_stacked=N)
36 changes: 24 additions & 12 deletions prism.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
import ConfigParser
import utils

try:
from tfwriter import convert_to_tf
except ImportError:
print("Tensorflow not available")
import tensorflow as tf
from tfwriter import convert_to_tf


def recursive_mkdir(path):
split_dir = path.split("/")
Expand Down Expand Up @@ -299,10 +298,17 @@ def main_prism_tf(config, model='srcnn'):
minyear = int(config.get('DataOptions', 'min_year'))
maxyear = int(config.get('DataOptions', 'max_year'))
patch_size = int(config.get('SRCNN', 'training_input_size'))
hr_resolution_km = 4
scale2 = 1./2 # scale2 is relative to scale1
#for scale1 in [1./16, 1., 1./2, 1./4, 1./8]:
for scale1 in [1./8, 1./4, 1./2, 1.]:

highest_resolution = 4
hr_resolution_km = config.getint('DeepSD', 'high_resolution')
lr_resolution_km = config.getint('DeepSD', 'low_resolution')
upscale_factor = config.getint('DeepSD', 'upscale_factor')

start = hr_resolution_km / highest_resolution
N = int((lr_resolution_km / hr_resolution_km)**(1./upscale_factor))

scale2 = 1./upscale_factor # scale2 is relative to scale1
for scale1 in [start * scale2**i for i in range(N)]:
save_dir = os.path.join(config.get('Paths', 'scratch'),
'%s_%03i_%03i' % (var, hr_resolution_km/scale1,
hr_resolution_km/(scale1*scale2)))
Expand All @@ -316,21 +322,27 @@ def main_prism_tf(config, model='srcnn'):
print "Making patches or year:", y
tf_file = os.path.join(save_dir, 'train_%i.tfrecords' % y)
print tf_file
if 1: # not os.path.exists(tf_file):
if not os.path.exists(tf_file):
print "trying to make patches"
d.make_patches(tf_file, size=patch_size, stride=20, scale1=scale1, scale2=scale2)
else:
print "Building test set for year:", y
tf_file = os.path.join(save_dir, 'test_%i.tfrecords' % y)
print tf_file
if 1: #not os.path.exists(tf_file):
if not os.path.exists(tf_file):
d.make_tf_test(tf_file, scale1, scale2)


if __name__ == "__main__":
config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.ini')
flags = tf.flags
flags.DEFINE_string('config_file', 'config.ini', 'Configuration file with [SRCNN] section.')

# parse flags
FLAGS = flags.FLAGS
FLAGS._parse_flags()

config = ConfigParser.ConfigParser()
config.read(config_file)
config.read(FLAGS.config_file)
data_dir = config.get('Paths', 'prism')
min_year = int(config.get('DataOptions', 'min_year'))
max_year = int(config.get('DataOptions', 'max_year'))
Expand Down
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@
KEEP_PROB = 1. - float(config.get('SRCNN', 'dropout_prob'))

# where to save and get data
DATA_DIR = config.get('Experiment-%s' % FLAGS.experiment_number, 'data_dir')
data_name = os.path.basename(DATA_DIR.strip("/"))
DATA_DIR = config.get('Model-%s' % FLAGS.experiment_number, 'data_dir')
MODEL_NAME = config.get('Model-%s' % FLAGS.experiment_number, 'model_name')
timestamp = str(int(time.time()))
curr_time = dt.datetime.now()

SAVE_DIR = os.path.join(config.get('SRCNN', 'scratch'), "srcnn_%s_%s_%s" % ( data_name,
SAVE_DIR = os.path.join(config.get('SRCNN', 'scratch'), "srcnn_%s_%s_%s" % ( MODEL_NAME,
'-'.join([str(s) for s in LAYER_SIZES]),
'-'.join([str(s) for s in KERNEL_SIZES])))

Expand Down

0 comments on commit 8566f06

Please sign in to comment.