In [1]:
%load_ext autoreload
%autoreload 2

In [46]:
from nngp import NN
from dataloader.rosen import RosenData 
import numpy as np
import tensorflow as tf
from sklearn.metrics import mean_squared_error as mse


CPU times: user 21 µs, sys: 2 µs, total: 23 µs
Wall time: 29.6 µs


In [32]:
config = {
    'random_state': 4623457,
    'n_dim': 10,
    'n_train': 200,
    'n_test': 200,
    'n_pool': 1000,
    'layers': [128, 64, 32],
    'update_sample_size': 10,
    'al_iterations': 10
}



In [12]:
np.random.seed(config['random_state'])

In [33]:
X_train, y_train, _, _, X_test, y_test, X_pool, y_pool = RosenData(
    config['n_train'], 0, config['n_test'], config['n_pool'], config['n_dim']
).dataset()

print('shapes:', X_train.shape, y_train.shape)
print('shapes:', X_test.shape, y_test.shape)
print('shapes:', X_pool.shape, y_pool.shape)

shapes: (200, 10) (200, 1)
shapes: (200, 10) (200, 1)
shapes: (1000, 10) (1000, 1)


In [35]:
# tf.reset_default_graph()
# nn = NN(
#     ndim = config['ndim'],
#     random_state = config['random_state'],
#     layers = config['layers']
# )
tf.reset_default_graph()
from model.mlp import MLP

nn = MLP(
    ndim = config['n_dim'],
    random_state = config['random_state'],
    layers = config['layers']
)

In [36]:
try:
    sess.close()
except:
    pass
# a setting for my cluster; ignore it
session_config = tf.ConfigProto()
session_config.gpu_options.allow_growth = True

# global init
init = tf.global_variables_initializer()
# saver = tf.train.Saver()
sess = tf.Session(config=session_config)
sess.run(init)

In [None]:
from uncertainty_estimator.mcdue import MCDUE

estimator = MCDUE(nn)

uncertainties = estimator.estimate(sess, X_pool)
uncertainties

In [61]:
from sample_selector.eager import EagerSampleSelector
from oracle.identity import IdentityOracle

oracle = IdentityOracle(y_pool)
sampler = EagerSampleSelector(oracle)

In [59]:
def print_shapes(note, *sets):
    print(note)
    for x, y in sets:
        print("shapes:", x.shape, y.shape)
iteration = 0
note = f'[{iteration}] BEFORE:'
print_shapes(note, (X_train, y_train), (X_test, y_test), (X_pool, y_pool))
    
    

[0] BEFORE:
shapes: (300, 10) (300, 1)
shapes: (200, 10) (200, 1)
shapes: (900, 10) (900, 1)


In [63]:
nn.train(sess, X_train, y_train, X_test, y_test, X_test, y_test)

rmses = [np.sqrt(mse(nn.predict(sess, data=X_test), y_test))]


for al_iteration in range(1, config['al_iterations']+1):
    note = f'[{al_iteration}] BEFORE:'
    print_shapes(note, (X_train, y_train), (X_test, y_test), (X_pool, y_pool))
    
    # update pool
    uncertainties = estimator.estimate(sess, X_pool)
    X_train, y_train, X_pool = sampler.update_sets(
        X_train, y_train, X_pool, uncertainties, config['update_sample_size']
    )
    
    note = f'[{al_iteration}] AFTER:'
    print_shapes(note, (X_train, y_train), (X_test, y_test), (X_pool, y_pool))
    
    # retrain net
    nn.train(sess, X_train, y_train, X_test, y_test, X_test, y_test)
    rmses.append(np.sqrt(mse(nn.predict(sess, data=X_test), y_test)))

[100] RMSE train:2.790 test:26.594 val:26.594 patience:3
[200] RMSE train:2.702 test:26.603 val:26.603 patience:2
[300] RMSE train:2.616 test:26.608 val:26.608 patience:1
[400] RMSE train:2.532 test:26.626 val:26.626 patience:0
No patience left at epoch 400. Early stopping.
[0] BEFORE:
shapes: (300, 10) (300, 1)
shapes: (200, 10) (200, 1)
shapes: (900, 10) (900, 1)
[0] AFTER:
shapes: (310, 10) (310, 1)
shapes: (200, 10) (200, 1)
shapes: (890, 10) (900, 1)
[100] RMSE train:4.460 test:26.462 val:26.462 patience:3
[200] RMSE train:3.954 test:26.494 val:26.494 patience:2
[300] RMSE train:3.608 test:26.501 val:26.501 patience:1
[400] RMSE train:3.355 test:26.543 val:26.543 patience:0
No patience left at epoch 400. Early stopping.
[0] BEFORE:
shapes: (310, 10) (310, 1)
shapes: (200, 10) (200, 1)
shapes: (890, 10) (900, 1)
[0] AFTER:
shapes: (320, 10) (320, 1)
shapes: (200, 10) (200, 1)
shapes: (880, 10) (900, 1)
[100] RMSE train:3.918 test:26.487 val:26.487 patience:3
[200] RMSE train:3.517 

In [64]:
rmses

[26.62585574625964,
 26.542743277506155,
 26.36098988366554,
 26.103252418181214,
 25.545855594415663,
 25.274810861329673,
 25.306115159302863,
 25.183696802490022,
 25.535809766291223,
 25.363881511933776,
 25.312211025648406]

In [49]:
nn.predict(sess,
           data = X_test[:3])

array([[215.42189],
       [151.62822],
       [179.41277]], dtype=float32)

In [50]:
y_test[:3]

array([[195.80091804],
       [145.94501192],
       [181.60245989]])

In [53]:
estimator.estimate(sess, X_pool)

array([207.82082752, 219.78968085, 135.43199557, 195.23482891,
       224.11474786, 190.80263021, 201.36593892, 187.02425998,
       106.95878852, 200.70247504, 169.94889073, 131.0161973 ,
       195.04326917, 189.0711145 , 239.01674489, 157.56752344,
       141.83347819, 152.5048312 , 163.04307004, 141.57128053,
       259.88909194, 182.58735583, 189.42855492, 155.4397427 ,
       144.45599738, 215.81614216, 207.56312347, 132.13015647,
       155.42927544, 195.0086972 , 230.77147876, 249.00300527,
       158.95081354, 156.65555852, 242.02261296,  92.98318906,
       141.08616375, 159.24173742, 209.47220512, 170.0789674 ,
       214.21550222, 227.90919298, 111.49748759, 152.41449545,
       171.21943916, 241.44919499, 121.42658073, 124.40178762,
       110.70722361, 147.05321115, 165.15329633, 224.44386446,
       117.45615713, 188.24621186, 184.65992492, 188.73014071,
       239.65061939, 167.79999049, 168.27594999, 153.89520212,
       136.24349423, 212.00660087, 171.03682667, 133.70