In [1]:
# Given the values of a single peakon sampled on a coarse grid, this DNN
# predicts the values of the peakon on a finer grid.
# That is, given the values at the points marked @, it outputs its predictions
# for the values at the points marked . (including the four @ points at the
# corners of that fine patch).
#
#
# @   @   @   @   @   @
#
#
#
# @   @   @   @   @   @
#
#
#
# @   @   @...@   @   @
#         .....
#         .....
#         .....
# @   @   @...@   @   @
#
#
#
# @   @   @   @   @   @
#
#
#
# @   @   @   @   @   @

In [2]:
%run project_base.ipynb

In [3]:
# Set to False to skip training and only use the latest checkpoint already
# saved in model_dir.
# Set to True to train. Training resumes from that checkpoint (if any), but
# dnn.train is called with max_steps=steps, so it stops once the global step
# reaches `steps`. A checkpoint that has already reached `steps` is not
# trained further ("Skipping training since max_steps has already saved");
# raise `steps` to keep training.
TRAIN = True

In [4]:
def gen_one_data(c_range=(0, 10), t_range=(0, 40), x_offset_range=(-5, 5)):
    """Generates a random feature/label combination.

    Here the feature is the solution on a coarse grid and the label is the
    solution on a fine grid. Both are sampled from the same randomly chosen
    peakon, at the same randomly chosen (t, x) location.

    The sampling ranges are keyword arguments whose defaults are the values
    this function always used. The np.random.uniform calls happen in the
    original order (c, then t, then the x offset), so gen_one_data() with no
    arguments behaves exactly as before.

    Args:
        c_range: (low, high) passed to np.random.uniform to draw the peakon
            parameter c.
        t_range: (low, high) passed to np.random.uniform to draw the time t.
        x_offset_range: (low, high) passed to np.random.uniform to draw the
            offset of x from c * t.

    Returns:
        (X, y) where X = peakon.on_grid(coarse_grid((t, x))) and
        y = peakon.on_grid(fine_grid((t, x))).
    """
    # Random solution to the CH equation
    c = np.random.uniform(*c_range)
    peakon = Peakon(c=c)
    # Random location. c * t is presumably the crest position of
    # Peakon(c=c) at time t (i.e. it starts at x = 0) -- TODO confirm in
    # project_base. The offset keeps samples near the crest.
    t = np.random.uniform(*t_range)
    x = c * t + np.random.uniform(*x_offset_range)

    # Grids at the location
    cg = coarse_grid((t, x))
    fg = fine_grid((t, x))
    # Features: the solution on the coarse grid
    X = peakon.on_grid(cg)
    # Labels: the solution on the fine grid
    y = peakon.on_grid(fg)

    return X, y

class BilinearInterp(BilinearInterpBase):
    """Baseline regressor: bilinear interpolation of the coarse-grid values."""

    @classmethod
    def _predict(cls, Xi):
        """Interpolate Xi at every point of the fine grid, in grid order.

        Returns a list holding cls._interp(Xi, point) for each point yielded
        by fine_grid((0, 0)).

        Translation is irrelevant at this stage, so without loss of generality
        the fine grid is placed at (0, 0). cls._interp makes the same
        assumption, and the two must stay consistent.
        """
        return [cls._interp(Xi, point) for point in fine_grid((0, 0))]

In [5]:
model_dir = './saved_models/first_reg/'

# DNN hyperparameters
hidden_units = [1000] * 20
# One output per fine-grid point. Derived from fine_grid (the same grid that
# gen_one_data uses for the labels) rather than hard-coded, so it cannot drift
# out of sync with the labels.
logits = len(list(fine_grid((0, 0))))
drop_rate = 0.4
batch_size = 100
batch_reuse = 1  # See BatchData for an explanation of batch reuse.
steps = 20000
log_steps = 1000  # Number of steps between reports of the current loss.

# Encode the hyperparameters above in the checkpoint directory name, so that
# different configurations do not share (and silently restore) each other's
# checkpoints. `steps` and the initializer are not part of the name.
# The dropout tag drops the decimal point from repr(drop_rate):
#   0.4 -> 'D04', 0.0 -> 'D00'.
# This matches the old 'D0{int(drop_rate * 10)}' tag for one-decimal rates, so
# existing checkpoints are still found. Unlike the old tag, it keeps e.g. 0.35
# ('D035') distinct from 0.3 ('D03').
model_dir += '_'.join(
    [str(units) for units in hidden_units]
    + [str(logits),
       'D' + repr(float(drop_rate)).replace('.', ''),
       'BS{}'.format(batch_size),
       'BR{}'.format(batch_reuse)])

k_init = tfi.truncated_normal(mean=0, stddev=0.06)
model = Sequential()
for units in hidden_units:
    model.add(tfla.Dense(units=units,
                         activation=tf.nn.relu,
                         kernel_initializer=k_init))
    # Dropout is registered through add_train -- presumably only applied
    # while training; TODO confirm in project_base.
    model.add_train(tfla.Dropout(rate=drop_rate))
# Linear output layer: one value per fine-grid point.
model.add(tfla.Dense(units=logits))
model.set_kwargs(model_dir=model_dir,
                 config=tfe.RunConfig(log_step_count_steps=log_steps))

In [6]:
# Build the regressor from the layer stack and kwargs configured above. It is
# used below through the tf.estimator-style train(input_fn=..., max_steps=...)
# API, and the log shows it picking up model_dir and the RunConfig.
dnn = model.compile()

INFO:tensorflow:Using config: {'_model_dir': './saved_models/first_reg/1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_121_D04_BS100_BR1', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 1000, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fecfc816898>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


In [7]:
# Train only when TRAIN is set. max_steps is an absolute target for the global
# step, not a number of additional steps. If the checkpoint in model_dir has
# already reached `steps`, TF skips training, as in the log below.
if TRAIN:
    train_input_fn = BatchData(gen_one_data, batch_size, batch_reuse)
    dnn.train(input_fn=train_input_fn, max_steps=steps)

INFO:tensorflow:Skipping training since max_steps has already saved.


In [8]:
# Sample a test set with the same random generator used for training.
# gen_test_data comes from project_base; the number of samples is not visible
# here.
testing_data = gen_test_data(gen_one_data)

In [9]:
# Evaluate the DNN on the test set (the printed predictions are truncated).
# NOTE(review): the predictions below all sit near ~0.15, while the bilinear
# baseline in the next cell gives ~1.1 on the same data. The DNN looks close to
# a constant predictor; worth investigating (depth / dropout / init).
test_regressor(testing_data, dnn)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./saved_models/first_reg/1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_1000_121_D04_BS100_BR1/model.ckpt-20000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.


{'prediction': array([[0.14926415, 0.15409765, 0.16087928, 0.15257994, 0.15464315,
         0.15393828, 0.141984  , 0.16744604, 0.15511741, 0.13943319,
         0.16450201, 0.15790605, 0.15299832, 0.13827906, 0.18118531,
         0.16429103, 0.17651638, 0.15372189, 0.15470735, 0.17310521,
         0.17528853, 0.1794029 , 0.16731911, 0.16218502, 0.14034962,
         0.1649023 , 0.14658231, 0.17577329, 0.15583617, 0.15748532,
         0.17350795, 0.17501876, 0.15382826, 0.15524626, 0.16050107,
         0.13932014, 0.16038425, 0.15267   , 0.16207435, 0.17448118,
         0.15856459, 0.15161744, 0.14700493, 0.18476644, 0.16917386,
         0.14209846, 0.18142125, 0.14702185, 0.13393832, 0.13556312,
         0.14077205, 0.1709107 , 0.15446562, 0.14837456, 0.17186993,
         0.13683722, 0.14588793, 0.15285126, 0.16634555, 0.16339104,
         0.14849687, 0.17550542, 0.15128431, 0.19399827, 0.16663549,
         0.15932268, 0.1436305 , 0.15918793, 0.15121963, 0.16070648,
         0.14088   ,

In [10]:
# Baseline: the same evaluation, using bilinear interpolation of the
# coarse-grid values instead of the DNN.
test_regressor(testing_data, BilinearInterp)

{'prediction': array([[1.12793653, 1.12912279, 1.13030906, 1.13149532, 1.13268158,
         1.13386784, 1.1350541 , 1.13624036, 1.13742662, 1.13861288,
         1.24656265, 1.12556157, 1.12674534, 1.1279291 , 1.12911287,
         1.13029663, 1.13148039, 1.13266416, 1.13384792, 1.13503168,
         1.13621545, 1.24393792, 1.12318662, 1.12436788, 1.12554915,
         1.12673041, 1.12791168, 1.12909295, 1.13027421, 1.13145548,
         1.13263674, 1.13381801, 1.24131318, 1.12081166, 1.12199043,
         1.1231692 , 1.12434796, 1.12552673, 1.1267055 , 1.12788427,
         1.12906303, 1.1302418 , 1.13142057, 1.23868845, 1.1184367 ,
         1.11961297, 1.12078924, 1.12196551, 1.12314178, 1.12431805,
         1.12549432, 1.12667059, 1.12784686, 1.12902313, 1.23606372,
         1.11606174, 1.11723552, 1.11840929, 1.11958306, 1.12075683,
         1.12193061, 1.12310438, 1.12427815, 1.12545192, 1.1266257 ,
         1.23343898, 1.11368679, 1.11485806, 1.11602934, 1.11720061,
         1.11837188,