In [1]:
%load_ext autoreload
%autoreload 2

In [46]:
%%time
from nngp import NN
from dataloader.rosen import RosenData 
import numpy as np
import tensorflow as tf
from sklearn.metrics import mean_squared_error as mse


CPU times: user 21 µs, sys: 2 µs, total: 23 µs
Wall time: 29.6 µs


In [32]:
# Experiment settings, collected in one place so later cells only read from here.
config = dict(
    random_state=4623457,    # seed handed to np.random.seed and to the model constructor
    n_dim=10,                # input dimensionality of the generated data
    n_train=200,             # size of the initial training set
    n_test=200,              # size of the test set
    n_pool=1000,             # size of the pool that active learning draws from
    layers=[128, 64, 32],    # passed to the model as `layers` (presumably hidden-layer widths)
    update_sample_size=10,   # count handed to the sampler on every active-learning step
    al_iterations=10,        # number of active-learning iterations
)



In [12]:
# Seed NumPy's global RNG from the shared experiment seed.
np.random.seed(seed=config['random_state'])

In [33]:
# RosenData receives, in order: train size, 0, test size, pool size and the
# input dimensionality. The two outputs unpacked into `_` are discarded
# (presumably the validation split, which is requested with size 0).
X_train, y_train, _, _, X_test, y_test, X_pool, y_pool = RosenData(
    config['n_train'],
    0,
    config['n_test'],
    config['n_pool'],
    config['n_dim'],
).dataset()

# One line per split: train, test, pool.
for X_part, y_part in ((X_train, y_train), (X_test, y_test), (X_pool, y_pool)):
    print('shapes:', X_part.shape, y_part.shape)

shapes: (200, 10) (200, 1)
shapes: (200, 10) (200, 1)
shapes: (1000, 10) (1000, 1)


In [35]:
# Clear the default graph first so that re-running this cell builds the model
# in an empty graph instead of adding to whatever was there before.
tf.reset_default_graph()
from model.mlp import MLP

nn = MLP(ndim=config['n_dim'], random_state=config['random_state'], layers=config['layers'])

In [36]:
# Release the session left over from a previous run of this cell, if there is one.
try:
    sess.close()
except NameError:
    # Fresh kernel: no session has been created yet, nothing to close.
    pass
except Exception as exc:
    # Best effort: a stale session that fails to close must not prevent
    # opening a new one, but the failure should not be hidden either.
    print(f'Ignoring error while closing the previous session: {exc!r}')

# Cluster-specific setting: let TensorFlow grow its GPU memory allocation on
# demand instead of reserving it up front.
session_config = tf.ConfigProto()
session_config.gpu_options.allow_growth = True

# Create the initializer op only now, after the model cell has built its
# variables in the default graph, then run it in a new session.
init = tf.global_variables_initializer()
sess = tf.Session(config=session_config)
sess.run(init)

In [None]:
from uncertainty_estimator.mcdue import MCDUE

# Wrap the model in the MCDUE uncertainty estimator.
estimator = MCDUE(nn)

# Estimate uncertainty for the pool points; the full result stays in
# `uncertainties`, only the first few values are displayed to keep the
# cell output short.
uncertainties = estimator.estimate(sess, X_pool)
uncertainties[:10]

In [54]:
from sample_selector.eager import EagerSampleSelector

sampler = EagerSampleSelector()

In [43]:
def print_shapes(note, *sets):
    """Print a header line, then the shapes of each (x, y) pair.

    Args:
        note: Text printed first, on its own line.
        *sets: Any number of ``(x, y)`` pairs; both members must expose a
            ``.shape`` attribute.

    Returns:
        None. Output goes to stdout, one ``shapes:`` line per pair.
    """
    print(note)
    for features, targets in sets:
        print("shapes:", features.shape, targets.shape)
# Dataset sizes before any active-learning step has run (labelled as iteration 0).
iteration = 0
note = f'[{iteration}] BEFORE:'
print_shapes(
    note,
    (X_train, y_train),
    (X_test, y_test),
    (X_pool, y_pool),
)
    
    

[0] BEFORE:
shapes: (200, 10) (200, 1)
shapes: (200, 10) (200, 1)
shapes: (1000, 10) (1000, 1)


In [47]:
# NOTE: this cell reassigns X_train / y_train / X_pool / y_pool. Re-running it
# without first re-running the data-loading cell continues from the sets left
# behind by the previous run instead of starting over.

# Initial fit. The test split is passed for both trailing (X, y) pairs --
# presumably the test and validation slots of train(); confirm against the model.
nn.train(sess, X_train, y_train, X_test, y_test, X_test, y_test)

# rmses[0] is the test RMSE after the initial fit; every active-learning
# iteration below appends one more entry.
rmses = [np.sqrt(mse(nn.predict(sess, data=X_test), y_test))]

for al_iteration in range(1, config['al_iterations'] + 1):
    # Label the printout with the loop counter itself, so each iteration is
    # distinguishable and the cell does not rely on a variable from another cell.
    note = f'[{al_iteration}] BEFORE:'
    print_shapes(note, (X_train, y_train), (X_test, y_test), (X_pool, y_pool))

    # Score the current pool, then let the sampler produce the updated
    # train/pool sets from those scores.
    uncertainties = estimator.estimate(sess, X_pool)
    X_train, y_train, X_pool, y_pool = sampler.update_sets(
        X_train, y_train, X_pool, y_pool, uncertainties, config['update_sample_size']
    )

    note = f'[{al_iteration}] AFTER:'
    print_shapes(note, (X_train, y_train), (X_test, y_test), (X_pool, y_pool))

    # Train again on the updated training set and record the new test RMSE.
    nn.train(sess, X_train, y_train, X_test, y_test, X_test, y_test)
    rmses.append(np.sqrt(mse(nn.predict(sess, data=X_test), y_test)))

[100] RMSE train:70.441 test:65.314 val:65.314 patience:3
[200] RMSE train:70.386 test:65.284 val:65.284 patience:3
[300] RMSE train:70.327 test:65.252 val:65.252 patience:3
[400] RMSE train:70.267 test:65.220 val:65.220 patience:3
[500] RMSE train:70.204 test:65.188 val:65.188 patience:3
[600] RMSE train:70.139 test:65.155 val:65.155 patience:3
[700] RMSE train:70.072 test:65.122 val:65.122 patience:3
[800] RMSE train:70.004 test:65.089 val:65.089 patience:3
[900] RMSE train:69.934 test:65.056 val:65.056 patience:3
[1000] RMSE train:69.862 test:65.023 val:65.023 patience:3
[1100] RMSE train:69.788 test:64.990 val:64.990 patience:3
[1200] RMSE train:69.714 test:64.958 val:64.958 patience:3
[1300] RMSE train:69.638 test:64.926 val:64.926 patience:3
[1400] RMSE train:69.561 test:64.895 val:64.895 patience:3
[1500] RMSE train:69.483 test:64.865 val:64.865 patience:3
[1600] RMSE train:69.404 test:64.835 val:64.835 patience:3
[1700] RMSE train:69.324 test:64.806 val:64.806 patience:3
[1800]

[100] RMSE train:5.430 test:25.285 val:25.285 patience:3
[200] RMSE train:5.071 test:25.290 val:25.290 patience:2
[300] RMSE train:4.790 test:25.273 val:25.273 patience:3
[400] RMSE train:4.547 test:25.256 val:25.256 patience:3
[500] RMSE train:4.330 test:25.258 val:25.258 patience:2
[600] RMSE train:4.133 test:25.280 val:25.280 patience:1
[700] RMSE train:3.944 test:25.282 val:25.282 patience:0
No patience left at epoch 700. Early stopping.
[0] BEFORE:
shapes: (250, 10) (250, 1)
shapes: (200, 10) (200, 1)
shapes: (950, 10) (950, 1)
[0] AFTER:
shapes: (260, 10) (260, 1)
shapes: (200, 10) (200, 1)
shapes: (940, 10) (940, 1)
[100] RMSE train:6.875 test:25.353 val:25.353 patience:3
[200] RMSE train:6.375 test:25.435 val:25.435 patience:2
[300] RMSE train:6.016 test:25.483 val:25.483 patience:1
[400] RMSE train:5.703 test:25.541 val:25.541 patience:0
No patience left at epoch 400. Early stopping.
[0] BEFORE:
shapes: (260, 10) (260, 1)
shapes: (200, 10) (200, 1)
shapes: (940, 10) (940, 1)
[

In [48]:
# Test-set RMSE history: entry 0 is after the initial fit, followed by one
# entry per active-learning iteration.
rmses

[26.43332185937285,
 27.214255117293217,
 26.89824332628774,
 25.825581002124924,
 25.53699742709461,
 25.281566323268205,
 25.540511073203586,
 25.15608406195817,
 25.608140684649683,
 26.145912573814993,
 26.50371456122752]

In [49]:
# Model predictions for the first three test points.
nn.predict(sess, data=X_test[:3])

array([[215.42189],
       [151.62822],
       [179.41277]], dtype=float32)

In [50]:
# Ground-truth targets for the first three test points.
y_test[:3]

array([[195.80091804],
       [145.94501192],
       [181.60245989]])

In [53]:
# Uncertainty estimates for the points currently in X_pool.
# NOTE(review): this displays the whole returned array; consider slicing it
# (e.g. [:10]) to keep the cell output short.
estimator.estimate(sess, X_pool)

array([207.82082752, 219.78968085, 135.43199557, 195.23482891,
       224.11474786, 190.80263021, 201.36593892, 187.02425998,
       106.95878852, 200.70247504, 169.94889073, 131.0161973 ,
       195.04326917, 189.0711145 , 239.01674489, 157.56752344,
       141.83347819, 152.5048312 , 163.04307004, 141.57128053,
       259.88909194, 182.58735583, 189.42855492, 155.4397427 ,
       144.45599738, 215.81614216, 207.56312347, 132.13015647,
       155.42927544, 195.0086972 , 230.77147876, 249.00300527,
       158.95081354, 156.65555852, 242.02261296,  92.98318906,
       141.08616375, 159.24173742, 209.47220512, 170.0789674 ,
       214.21550222, 227.90919298, 111.49748759, 152.41449545,
       171.21943916, 241.44919499, 121.42658073, 124.40178762,
       110.70722361, 147.05321115, 165.15329633, 224.44386446,
       117.45615713, 188.24621186, 184.65992492, 188.73014071,
       239.65061939, 167.79999049, 168.27594999, 153.89520212,
       136.24349423, 212.00660087, 171.03682667, 133.70