From 8f4cf9c2265917d121ebebde0dce8f636160c1e2 Mon Sep 17 00:00:00 2001
From: thjashin
Date: Tue, 28 May 2019 00:03:09 +0800
Subject: [PATCH] add classification code

---
 classification.py          | 147 +++++++++++++++++++
 classification/__init__.py |   6 +
 classification/fbnn.py     | 100 +++++++++++++
 classification/gpnet.py    | 293 +++++++++++++++++++++++++++++++++++++
 classification/svgp.py     |  57 ++++++++
 kernel/__init__.py         |   2 +
 kernel/conv.py             | 155 ++++++++++++++++++++
 kernel/elementwise.py      |  63 ++++++++
 kernel/resnet.py           | 125 ++++++++++++++++
 9 files changed, 948 insertions(+)
 create mode 100644 classification.py
 create mode 100644 classification/__init__.py
 create mode 100644 classification/fbnn.py
 create mode 100644 classification/gpnet.py
 create mode 100644 classification/svgp.py
 create mode 100644 kernel/__init__.py
 create mode 100644 kernel/conv.py
 create mode 100644 kernel/elementwise.py
 create mode 100644 kernel/resnet.py

diff --git a/classification.py b/classification.py
new file mode 100644
index 0000000..c0741a8
--- /dev/null
+++ b/classification.py
@@ -0,0 +1,147 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function
+from __future__ import division
+import os
+from collections import namedtuple
+
+import tensorflow as tf
+import numpy as np
+
+from kernel.resnet import ResnetKernel
+from kernel.elementwise import ReLUKernel
+from utils.log import setup_logger
+from utils.data import load_mnist_realval, load_cifar10
+from classification import svgp, gpnet, gpnet_nonconj, fbnn
+
+
+FLAGS = tf.flags.FLAGS
+tf.flags.DEFINE_string("dataset", "mnist", "Dataset.")
+tf.flags.DEFINE_string("method", "gpnet_nonconj", """Inference method.""")
+tf.flags.DEFINE_string("net", "tangent", "Inference network.")
+tf.flags.DEFINE_integer("batch_size", 128, """Total batch size.""")
+tf.flags.DEFINE_float("learning_rate", 3e-4, """Learning rate.""")
+tf.flags.DEFINE_integer("n_inducing", 100, """Number of inducing points.""")
+tf.flags.DEFINE_string("measure", "train", "Measurement set.")
+tf.flags.DEFINE_float("hyper_rate", 0, "Hyperparameter update rate.")
+tf.flags.DEFINE_integer("block_size", 2, "Number of blocks per stage.")
+tf.flags.DEFINE_float("beta0", 0.01, """Initial beta value.""")
+tf.flags.DEFINE_float("gamma", 0.1, """Beta schedule.""")
+tf.flags.DEFINE_integer("n_iters", 10000, """Number of training iterations.""")
+tf.flags.DEFINE_string("note", "", "Note for random experiments.")
+
+
+def main():
+    flag_values = [
+        ("method", FLAGS.method),
+        ("net", FLAGS.net),
+        ("inducing", FLAGS.n_inducing),
+        ("beta0", FLAGS.beta0),
+        ("gamma", FLAGS.gamma),
+        ("niter", FLAGS.n_iters),
+        ("bs", FLAGS.batch_size // 2),
+        ("m", FLAGS.batch_size // 2),
+        ("lr", FLAGS.learning_rate),
+        ("measure", FLAGS.measure),
+        ("hyper_rate", FLAGS.hyper_rate),
+        ("block", FLAGS.block_size),
+        ("note", FLAGS.note),
+    ]
+    flag_str = "$".join(["@".join([i[0], str(i[1])]) for i in flag_values])
+    result_path = os.path.join(
+        "results", "classification", FLAGS.dataset, flag_str)
+    logger = setup_logger("classification", __file__, result_path,
+                          filename="log")
+
+    np.random.seed(1234)
+    tf.set_random_seed(1234)
+
+    # Load data
+    if FLAGS.dataset == "mnist":
+        train_x, train_y, valid_x, valid_y, test_x, test_y = load_mnist_realval(
+            dtype=np.float64)
+        train_x = np.vstack([train_x, valid_x])
+        train_y = np.vstack([train_y, valid_y])
+        input_shape = [1, 28, 28]
+    elif FLAGS.dataset == "cifar10":
+        train_x, train_y, test_x, test_y = load_cifar10(
+            dtype=np.float64)
+        input_shape = [3, 32, 32]
+    else:
+        raise NotImplementedError()
+
+    train_x = 2 * train_x - 1
+    test_x = 2 * test_x - 1
+
+    train = tf.data.Dataset.from_tensor_slices((train_x, train_y))
+    test = tf.data.Dataset.from_tensor_slices((test_x, test_y))
+    train = train.shuffle(buffer_size=1000).batch(
+        FLAGS.batch_size // 2).repeat()
+    test = test.batch(FLAGS.batch_size * 4)
+
+    if FLAGS.measure == "test_x":
+        measure = tf.data.Dataset.from_tensor_slices(test_x)
+    else:
+        measure = tf.data.Dataset.from_tensor_slices(train_x)
+    measure = measure.shuffle(buffer_size=1000).batch(
+        FLAGS.batch_size // 2).repeat()
+    measure_iterator = measure.make_one_shot_iterator()
+    measure_batch = measure_iterator.get_next()
+
+    handle = tf.placeholder(tf.string, shape=[])
+    iterator = tf.data.Iterator.from_string_handle(
+        handle, train.output_types, train.output_shapes)
+    next_batch = iterator.get_next()
+
+    train_iterator = train.make_one_shot_iterator()
+    test_iterator = test.make_initializable_iterator()
+
+    sess = tf.Session()
+
+    train_handle = sess.run(train_iterator.string_handle())
+    test_handle = sess.run(test_iterator.string_handle())
+
+    Data = namedtuple("Data", [
+        "next_batch",
+        "measure_batch",
+        "handle",
+        "train_handle",
+        "test_handle",
+        "test_iterator",
+        "train_x",
+        "train_y"])
+    data = Data(next_batch, measure_batch, handle, train_handle, test_handle,
+                test_iterator, train_x, train_y)
+
+    block_sizes = [FLAGS.block_size] * 3
+    block_strides = [1, 2, 2]
+    with tf.variable_scope("prior"):
+        resnet_kern = ResnetKernel(
+            input_shape=input_shape,
+            block_sizes=block_sizes,
+            block_strides=block_strides,
+            kernel_size=3,
+            recurse_kern=ReLUKernel(),
+            var_weight=1.,
+            var_bias=0.,
+            conv_stride=1,
+            data_format="NCHW",
+            dtype=tf.float64,
+        )
+
+    sess.run(tf.variables_initializer(tf.trainable_variables("prior")))
+
+    # Run the chosen inference method
+    if FLAGS.method == "svgp":
+        svgp(logger, sess, data, resnet_kern)
+    elif FLAGS.method == "gpnet":
+        gpnet(logger, sess, data, resnet_kern, dtype=tf.float64)
+    elif FLAGS.method == "gpnet_nonconj":
+        gpnet_nonconj(logger, sess, data, resnet_kern, dtype=tf.float64)
+    elif FLAGS.method == "fbnn":
+        fbnn(logger, sess, data, resnet_kern, dtype=tf.float64)
+    else:
+        raise NotImplementedError()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/classification/__init__.py b/classification/__init__.py
new file mode 100644
index 0000000..b8d1056
--- /dev/null
+++ b/classification/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from .gpnet import gpnet, gpnet_nonconj
+from .svgp import svgp
+from .fbnn import fbnn
diff --git a/classification/fbnn.py b/classification/fbnn.py
new file mode 100644
index 0000000..2550bf9
--- /dev/null
+++ b/classification/fbnn.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function
+from __future__ import division
+
+import tensorflow as tf
+import numpy as np
+import zhusuan as zs
+import gpflowSlim as gpflow
+
+from bnn.multi_output.resnet import build_resnet
+from utils.mvn import multivariate_normal_kl
+
+
+FLAGS = tf.flags.FLAGS
+
+
+def fbnn(logger, sess, data, kernel, dtype=tf.float64):
+    train_x = data.train_x
+    N, x_dim = train_x.shape
+    _, n_cls = data.train_y.shape
+
+    bnn = build_resnet(
+        "bnn",
+        n_cls,
+        kernel.input_shape,
+        kernel.block_sizes,
+        kernel.block_strides,
+        data_format="NCHW",
+        dtype=dtype,
+        net=FLAGS.net)
+
+    x, y = data.next_batch
+    # x_star: [bs_star, x_dim], x: [bs, x_dim]
+    x_star = data.measure_batch
+    # xx: [bs + bs_star, x_dim]
+    xx = tf.concat([x, x_star], axis=0)
+
+    qff = bnn(xx)
+    # qf_mean: [n_cls, bs], qf_var: [n_cls, bs], f_pred: [n_cls, bs]
+    qf_mean, qf_var = bnn(x, full_cov=False)
+    f_pred = qf_mean + tf.sqrt(qf_var) * tf.random_normal(tf.shape(qf_mean),
+                                                          dtype=dtype)
+    # y_pred: [bs]
+    y_pred = tf.argmax(qf_mean, axis=0, output_type=tf.int32)
+    # y_target: [bs]
+    y_target = tf.argmax(y, axis=1, output_type=tf.int32)
+    # acc: []
+    acc = tf.reduce_mean(tf.to_float(tf.equal(y_pred, y_target)))
+
+    # K_prior: [bs + bs_star, bs + bs_star]
+    K_prior = kernel.K(xx)
+    K_prior_tril = tf.cholesky(
+        K_prior + tf.eye(tf.shape(xx)[0], dtype=dtype) * gpflow.settings.jitter)
+    pff = zs.distributions.MultivariateNormalCholesky(
+        tf.zeros([n_cls, tf.shape(xx)[0]], dtype=dtype),
+        tf.tile(K_prior_tril[None, ...], [n_cls, 1, 1]))
+
+    # likelihood term
+    f_term = -tf.nn.softmax_cross_entropy_with_logits(
+        labels=y,
+        logits=tf.matrix_transpose(f_pred))
+    f_term = tf.reduce_sum(f_term)
+
+    # kl term
+    kl_term = tf.reduce_sum(multivariate_normal_kl(qff, pff))
+
+    lower_bound = f_term - kl_term
+
+    fbnn_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
+    bnn_var = tf.trainable_variables(scope="bnn")
+    infer_fbnn = fbnn_opt.minimize(-lower_bound, var_list=bnn_var)
+    print_freq = 1
+    test_freq = 100
+    sess.run(tf.variables_initializer(var_list=bnn_var + fbnn_opt.variables()))
+    train_stats = []
+    for t in range(1, FLAGS.n_iters + 1):
+        _, train_ll, train_acc = sess.run(
+            [infer_fbnn, lower_bound, acc],
+            feed_dict={data.handle: data.train_handle})
+        train_stats.append((train_ll, train_acc))
+
+        if t % print_freq == 0:
+            train_lls, train_accs = list(zip(*train_stats))
+            logger.info("Iter {}, lower bound = {:.4f}, train acc = {:.4f}"
+                        .format(t, np.mean(train_lls), np.mean(train_accs)))
+            train_stats = []
+
+        if t % test_freq == 0:
+            sess.run(data.test_iterator.initializer)
+            test_stats = []
+            while True:
+                try:
+                    test_stats.append(
+                        sess.run(acc,
+                                 feed_dict={data.handle: data.test_handle}))
+                except tf.errors.OutOfRangeError:
+                    break
+            logger.info(">> Test acc = {:.4f}".format(np.mean(test_stats)))
diff --git a/classification/gpnet.py b/classification/gpnet.py
new file mode 100644
index 0000000..92af653
--- /dev/null
+++ b/classification/gpnet.py
@@ -0,0 +1,293 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function
+from __future__ import division
+
+import tensorflow as tf
+import numpy as np
+import zhusuan as zs
+import gpflowSlim as gpflow
+
+from bnn.multi_output.resnet import build_resnet
+from utils.mvn import (multivariate_normal_cross_entropy,
+                       multivariate_normal_entropy,
+                       multivariate_normal_kl)
+
+
+FLAGS = tf.flags.FLAGS
+
+
+def gpnet(logger, sess, data, kernel, dtype=tf.float64):
+    train_x = data.train_x
+    N, x_dim = train_x.shape
+    _, n_cls = data.train_y.shape
+
+    bnn_prev = build_resnet(
+        "bnn_prev",
+        n_cls,
+        kernel.input_shape,
+        kernel.block_sizes,
+        kernel.block_strides,
+        data_format="NCHW",
+        mvn=False,
+        dtype=dtype,
+        net=FLAGS.net)
+    bnn = build_resnet(
+        "bnn",
+        n_cls,
+        kernel.input_shape,
+        kernel.block_sizes,
+        kernel.block_strides,
+        data_format="NCHW",
+        dtype=dtype,
+        net=FLAGS.net)
+
+    x, y = data.next_batch
+    n = tf.shape(x)[0]
+    # x_star: [bs_star, x_dim], x: [bs, x_dim]
+    x_star = data.measure_batch
+    beta = tf.placeholder(dtype, shape=[], name="beta")
+    # xx: [bs + bs_star, x_dim]
+    xx = tf.concat([x, x_star], axis=0)
+
+    qf_star = bnn(x_star)
+    # qff_mean_prev: [n_cls, bs + bs_star]
+    # K_prox: [n_cls, bs + bs_star, bs + bs_star]
+    qff_mean_prev, K_prox = bnn_prev(xx)
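+    # NOTE: bnn_prev is a frozen copy of the inference network; the assign
+    # ops built below overwrite its weights with bnn's after every step. Its
+    # posterior (qff_mean_prev, K_prox) anchors the proximal update: the
+    # regression target constructed further down uses the blended covariance
+    # beta * K_prox + (1 - beta) * K_prior.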
+    # qf_mean: [n_cls, bs], qf_var: [n_cls, bs]
+    qf_mean, qf_var = bnn(x, full_cov=False)
+    # y_pred: [bs]
+    y_pred = tf.argmax(qf_mean, axis=0, output_type=tf.int32)
+    # y_target: [bs]
+    y_target = tf.argmax(y, axis=1, output_type=tf.int32)
+    # acc: []
+    acc = tf.reduce_mean(tf.to_float(tf.equal(y_pred, y_target)))
+
+    bnn_ops = []
+    bnn_prev_var = tf.trainable_variables(scope="bnn_prev")
+    bnn_var = tf.trainable_variables(scope="bnn")
+    for (prev, cur) in zip(bnn_prev_var, bnn_var):
+        bnn_ops.append(prev.assign(cur))
+    bnn_op = tf.group(bnn_ops)
+
+    with tf.variable_scope("likelihood"):
+        likelihood = gpflow.likelihoods.Gaussian(var=0.1)
+        likelihood_var = likelihood.variance
+    # K_prior: [bs + bs_star, bs + bs_star]
+    K_prior = kernel.K(xx) + tf.eye(
+        tf.shape(xx)[0], dtype=dtype) * gpflow.settings.jitter
+
+    # K_sum_tril: [n_cls, bs + bs_star, bs + bs_star]
+    K_sum_tril = tf.cholesky(K_prox * beta + K_prior * (1 - beta))
+    # K_sum_tril_inv: [n_cls, bs + bs_star, bs + bs_star]
+    K_sum_tril_inv = tf.matrix_triangular_solve(
+        K_sum_tril,
+        tf.tile(tf.eye(tf.shape(xx)[0], dtype=dtype)[None, ...],
+                [n_cls, 1, 1]))
+    # K_sum_inv: [n_cls, bs + bs_star, bs + bs_star]
+    K_sum_inv = tf.matmul(K_sum_tril_inv, K_sum_tril_inv, transpose_a=True)
+    # K_adapt: [n_cls, bs + bs_star, bs + bs_star]
+    K_adapt = tf.matmul(tf.tile(K_prior[None, ...], [n_cls, 1, 1]),
+                        tf.matmul(K_sum_inv, K_prox))
+    # mean_adapt: [n_cls, bs + bs_star, 1]
+    mean_adapt = (1 - beta) * tf.matmul(
+        K_adapt, tf.matrix_solve(K_prox, qff_mean_prev[..., None]))
+    # mean_n: [n_cls, bs, 1], mean_m: [n_cls, bs_star, 1]
+    mean_n, mean_m = mean_adapt[:, :n, :], mean_adapt[:, n:, :]
+    # Kn: [n_cls, bs, bs],
+    # Knm: [n_cls, bs, bs_star],
+    # Km: [n_cls, bs_star, bs_star]
+    Kn, Knm, Km = K_adapt[:, :n, :n], K_adapt[:, :n, n:], K_adapt[:, n:, n:]
+
+    # Ky: [n_cls, bs, bs]
+    Ky = Kn + tf.eye(n, dtype=dtype) * likelihood_var / (
+        N / tf.cast(n, dtype) * beta)
+    # Ky_tril: [n_cls, bs, bs]
+    Ky_tril = tf.cholesky(Ky)
+
+    # y: [bs, n_cls]
+    # yf: [n_cls, bs]
+    yf = tf.cast(tf.matrix_transpose(y), dtype)
+    # mean_target: [n_cls, bs_star, 1]
+    mean_target = tf.matmul(
+        Knm, tf.cholesky_solve(Ky_tril, yf[..., None] - mean_n),
+        transpose_a=True) + mean_m
+    # mean_target: [n_cls, bs_star]
+    mean_target = tf.squeeze(mean_target, -1)
+    # K_target: [n_cls, bs_star, bs_star]
+    K_target = Km - tf.matmul(Knm, tf.cholesky_solve(Ky_tril, Knm),
+                              transpose_a=True)
+    K_target_tril = tf.cholesky(K_target)
+    target_pf_star = zs.distributions.MultivariateNormalCholesky(
+        mean_target, K_target_tril)
+
+    kl_obj = tf.reduce_sum(
+        multivariate_normal_kl(qf_star, target_pf_star, dtype=dtype))
+
+    # hyperparameter update (Gaussian likelihood variance)
+    Kn_prior = tf.tile(K_prior[None, :n, :n], [n_cls, 1, 1])
+    pf = zs.distributions.MultivariateNormalCholesky(
+        tf.zeros([n_cls, n], dtype=dtype), tf.cholesky(Kn_prior))
+    Kn_prox = K_prox[:, :n, :n]
+    qf_prev_mean = qff_mean_prev[:, :n]
+    qf_prev_var = tf.matrix_diag_part(Kn_prox)
+    qf_prev = zs.distributions.MultivariateNormalCholesky(
+        qf_prev_mean, tf.cholesky(Kn_prox))
+    hyper_obj = tf.reduce_sum(likelihood.variational_expectations(
+        qf_prev_mean, qf_prev_var, yf)) - tf.reduce_sum(
+        multivariate_normal_kl(qf_prev, pf))
+
+    gpnet_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
+    infer_gpnet = gpnet_opt.minimize(kl_obj, var_list=bnn_var)
+
+    hyper_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.hyper_rate)
+    gp_var_list = tf.trainable_variables(scope="likelihood")
+    if FLAGS.hyper_rate < 1e-8:
+        hyper_op = tf.no_op()
+    else:
+        hyper_op = hyper_opt.minimize(-hyper_obj, var_list=gp_var_list)
+
+    test_freq = 50
+    likelihood_var_list = tf.trainable_variables(scope="likelihood")
+    var_list = bnn_prev_var + bnn_var + likelihood_var_list
+    sess.run(tf.variables_initializer(
+        var_list=var_list + gpnet_opt.variables() + hyper_opt.variables()))
+    logger.info("prior gp var: {}".format(sess.run(likelihood_var)))
+
+    for t in range(1, FLAGS.n_iters + 1):
+        beta_t = FLAGS.beta0 * 1. / (1. + FLAGS.gamma * np.sqrt(t - 1))
+        _, _, train_obj, obs_var, train_acc = sess.run(
+            [infer_gpnet, hyper_op, kl_obj, likelihood_var, acc],
+            feed_dict={data.handle: data.train_handle,
+                       beta: beta_t})
+        sess.run(bnn_op)
+        logger.info("Iter {}, kl_obj = {}, obs_var = {}, train_acc = {}"
+                    .format(t, train_obj, obs_var, train_acc))
+
+        if t % test_freq == 0:
+            sess.run(data.test_iterator.initializer)
+            test_stats = []
+            while True:
+                try:
+                    test_stats.append(
+                        sess.run(acc,
+                                 feed_dict={data.handle: data.test_handle}))
+                except tf.errors.OutOfRangeError:
+                    break
+            logger.info(">> Test acc = {:.4f}".format(np.mean(test_stats)))
+
+
+def gpnet_nonconj(logger, sess, data, kernel, dtype=tf.float64):
+    train_x = data.train_x
+    N, x_dim = train_x.shape
+    _, n_cls = data.train_y.shape
+
+    bnn_prev = build_resnet(
+        "bnn_prev",
+        n_cls,
+        kernel.input_shape,
+        kernel.block_sizes,
+        kernel.block_strides,
+        data_format="NCHW",
+        dtype=dtype,
+        net=FLAGS.net)
+    bnn = build_resnet(
+        "bnn",
+        n_cls,
+        kernel.input_shape,
+        kernel.block_sizes,
+        kernel.block_strides,
+        data_format="NCHW",
+        dtype=dtype,
+        net=FLAGS.net)
+
+    x, y = data.next_batch
+    n = tf.shape(x)[0]
+    # x_star: [bs_star, x_dim], x: [bs, x_dim]
+    x_star = data.measure_batch
+    beta = tf.placeholder(dtype, shape=[], name="beta")
+    n_particles = tf.placeholder(tf.int32, shape=[], name="n_particles")
+    # xx: [bs + bs_star, x_dim]
+    xx = tf.concat([x, x_star], axis=0)
+
+    qff = bnn(xx)
+    qff_prev = bnn_prev(xx)
+    qf_mean, qf_var = bnn(x, full_cov=False)
+    f_pred = qf_mean + tf.sqrt(qf_var) * tf.random_normal(tf.shape(qf_mean),
+                                                          dtype=dtype)
+    # y_pred: [bs]
+    y_pred = tf.argmax(qf_mean, axis=0, output_type=tf.int32)
+    # y_target: [bs]
+    y_target = tf.argmax(y, axis=1, output_type=tf.int32)
+    # acc: []
+    acc = tf.reduce_mean(tf.to_float(tf.equal(y_pred, y_target)))
+
+    bnn_ops = []
+    bnn_prev_var = tf.trainable_variables(scope="bnn_prev")
+    bnn_var = tf.trainable_variables(scope="bnn")
+    for (prev, cur) in zip(bnn_prev_var, bnn_var):
+        bnn_ops.append(prev.assign(cur))
+    bnn_op = tf.group(bnn_ops)
+
+    # K_prior: [bs + bs_star, bs + bs_star]
+    K_prior = kernel.K(xx)
+    K_prior_tril = tf.cholesky(
+        K_prior + tf.eye(tf.shape(xx)[0], dtype=dtype) * gpflow.settings.jitter)
+    pff = zs.distributions.MultivariateNormalCholesky(
+        tf.zeros([n_cls, tf.shape(xx)[0]], dtype=dtype),
+        tf.tile(K_prior_tril[None, ...], [n_cls, 1, 1]))
+
+    # likelihood term
+    f_term = -tf.nn.softmax_cross_entropy_with_logits(
+        labels=y,
+        logits=tf.matrix_transpose(f_pred))
+    f_term = tf.reduce_sum(f_term)
+    f_term *= N / tf.cast(n, dtype) * beta
+
+    # prior term
+    prior_term = -beta * tf.reduce_sum(
+        multivariate_normal_cross_entropy(qff, pff))
+
+    # proximity term
+    prox_term = -(1 - beta) * tf.reduce_sum(
+        multivariate_normal_cross_entropy(qff, qff_prev))
+
+    # entropy term
+    entropy_term = tf.reduce_sum(multivariate_normal_entropy(qff))
+
+    lower_bound = f_term + prior_term + prox_term + entropy_term
+
+    gpnet_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
+    infer_gpnet = gpnet_opt.minimize(-lower_bound, var_list=bnn_var)
+    print_freq = 1
+    test_freq = 100
+    var_list = bnn_prev_var + bnn_var
+    sess.run(tf.variables_initializer(var_list=var_list + gpnet_opt.variables()))
+    train_stats = []
+    for t in range(1, FLAGS.n_iters + 1):
+        beta_t = FLAGS.beta0 * 1. / (1. + FLAGS.gamma * np.sqrt(t - 1))
+        _, train_ll, train_acc = sess.run(
+            [infer_gpnet, lower_bound, acc],
+            feed_dict={data.handle: data.train_handle,
+                       beta: beta_t})
+        train_stats.append((train_ll, train_acc))
+        sess.run(bnn_op)
+
+        if t % print_freq == 0:
+            train_lls, train_accs = list(zip(*train_stats))
+            logger.info("Iter {}, lower bound = {:.4f}, train acc = {:.4f}"
+                        .format(t, np.mean(train_lls), np.mean(train_accs)))
+            train_stats = []
+
+        if t % test_freq == 0:
+            sess.run(data.test_iterator.initializer)
+            test_stats = []
+            while True:
+                try:
+                    test_stats.append(
+                        sess.run(acc,
+                                 feed_dict={data.handle: data.test_handle}))
+                except tf.errors.OutOfRangeError:
+                    break
+            logger.info(">> Test acc = {:.4f}".format(np.mean(test_stats)))
diff --git a/classification/svgp.py b/classification/svgp.py
new file mode 100644
index 0000000..f453ea9
--- /dev/null
+++ b/classification/svgp.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import print_function
+from __future__ import division
+
+import tensorflow as tf
+import gpflowSlim as gpflow
+import numpy as np
+
+
+FLAGS = tf.flags.FLAGS
+
+
+def svgp(logger, sess, data, resnet_kern):
+    train_x = data.train_x
+    N, x_dim = train_x.shape
+    _, n_cls = data.train_y.shape
+    x_batch, y_batch = data.next_batch
+    y_target = tf.argmax(y_batch, axis=1, output_type=tf.int32)
+    with tf.variable_scope("svgp"):
+        likelihood = gpflow.likelihoods.MultiClass(n_cls)
+        # inducing_points, _ = kmeans2(train_x, FLAGS.n_inducing, minit="points")
+        idx = np.random.permutation(N)
+        inducing_points = train_x[idx[:FLAGS.n_inducing]]
+        svgp = gpflow.models.SVGP(
+            x_batch, tf.cast(y_target[:, None], tf.float64), resnet_kern, likelihood,
+            Z=inducing_points, num_latent=n_cls, num_data=N)
+        obj = svgp.objective
+
+    y_mean, y_var = svgp.predict_y(x_batch)
+    y_pred = tf.argmax(y_mean, axis=1, output_type=tf.int32)
+    acc = tf.reduce_mean(tf.to_float(tf.equal(y_pred, y_target)))
+
+    svgp_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
+    var_list = tf.trainable_variables(scope="svgp/GPModel")
+    infer_svgp = svgp_opt.minimize(obj, var_list=var_list)
+
+    test_freq = 50
+    sess.run(tf.global_variables_initializer())
+    for t in range(1, FLAGS.n_iters + 1):
+        _, train_ll = sess.run(
+            [infer_svgp, obj],
+            feed_dict={data.handle: data.train_handle})
+        logger.info("Iter {}, lower bound = {:.4f}".format(t, -train_ll))
+
+        if t % test_freq == 0:
+            sess.run(data.test_iterator.initializer)
+            test_stats = []
+            while True:
+                try:
+                    test_stats.append(
+                        sess.run(acc,
+                                 feed_dict={data.handle: data.test_handle}))
+                except tf.errors.OutOfRangeError:
+                    break
+            logger.info(">> Test acc = {:.4f}".format(np.mean(test_stats)))
diff --git a/kernel/__init__.py b/kernel/__init__.py
new file mode 100644
index 0000000..faa18be
--- /dev/null
+++ b/kernel/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
diff --git a/kernel/conv.py b/kernel/conv.py
new file mode 100644
index 0000000..9ca0586
--- /dev/null
+++ b/kernel/conv.py
@@ -0,0 +1,155 @@
+# Copyright 2019 Jiaxin Shi
+# Copyright 2017 https://github.com/rhaps0dy/convnets-as-gps
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+from __future__ import division
+
+import gpflowSlim as gpflow
+from typing import List
+import numpy as np
+import tensorflow as tf
+import abc
+
+from .elementwise import ElementwiseKernel
+
+
+class ConvKernelBase(gpflow.kernels.Kernel, metaclass=abc.ABCMeta):
+    "General kernel for deep networks"
+    def __init__(self,
+                 input_shape: List[int],
+                 block_sizes: List[int],
+                 block_strides: List[int],
+                 kernel_size: int,
+                 recurse_kern: ElementwiseKernel,
+                 var_weight: float = 1.0,
+                 var_bias: float = 1.0,
+                 conv_stride: int = 1,
+                 active_dims: slice = None,
+                 data_format: str = "NCHW",
+                 input_type=None,
+                 name: str = None,
+                 dtype=tf.float32):
+        input_dim = np.prod(input_shape)
+        super(ConvKernelBase, self).__init__(input_dim, active_dims, name=name)
+
+        self.input_shape = list(np.copy(input_shape))
+        self.block_sizes = np.copy(block_sizes).astype(np.int32)
+        self.block_strides = np.copy(block_strides).astype(np.int32)
+        self.kernel_size = kernel_size
+        self.recurse_kern = recurse_kern
+        self.conv_stride = conv_stride
+        self.data_format = data_format
+        if input_type is None:
+            input_type = dtype
+        self.input_type = input_type
+        self.dtype = dtype
+
+        self._var_weight = gpflow.params.Parameter(
+            var_weight, gpflow.transforms.positive, dtype=self.input_type,
+            name="var_weight")
+        self._var_bias = gpflow.params.Parameter(
+            var_bias, gpflow.transforms.positive, dtype=self.input_type,
+            name="var_bias")
+
+    @property
+    def var_weight(self):
+        return self._var_weight.value
+
+    @property
+    def var_bias(self):
+        return self._var_bias.value
+
+    @gpflow.decors.name_scope()
+    def K(self, X, X2=None):
+        # Concatenate the covariance between X and X2 and their respective
+        # variances. Only 1 variance is needed if X2 is None.
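+        # Layout of the stacked tensor fed through the network:
+        #   X2 is None: [sq(X) (N diagonal entries); X * X (N*N cross terms)]
+        #   X2 given:   [sq(X) (N); sq(X2) (N2); X * X2 (N*N2 cross terms)]
+        # apply_recurse_kern (defined per case below) splits the stack back
+        # into these pieces so the elementwise kernel can treat diagonal and
+        # cross blocks differently, then re-concatenates them.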
+        if X.dtype != self.input_type or (
+                X2 is not None and X2.dtype != self.input_type):
+            raise TypeError("Input dtypes are wrong: {} or {} are not {}"
+                            .format(X.dtype, X2.dtype, self.input_type))
+        if X2 is None:
+            N = N2 = tf.shape(X)[0]
+            var_z_list = [
+                tf.reshape(tf.square(X), [N] + self.input_shape),
+                tf.reshape(X[:, None, :] * X, [N*N] + self.input_shape)]
+
+            @gpflow.decors.name_scope("apply_recurse_kern_X_X")
+            def apply_recurse_kern(var_a_all, concat_outputs=True):
+                var_a_1 = var_a_all[:N]
+                var_a_cross = var_a_all[N:]
+                vz = [self.recurse_kern.Kdiag(var_a_1),
+                      self.recurse_kern.K(var_a_cross, var_a_1, None)]
+                if concat_outputs:
+                    return tf.concat(vz, axis=0)
+                return vz
+
+        else:
+            N, N2 = tf.shape(X)[0], tf.shape(X2)[0]
+            var_z_list = [
+                tf.reshape(tf.square(X), [N] + self.input_shape),
+                tf.reshape(tf.square(X2), [N2] + self.input_shape),
+                tf.reshape(X[:, None, :] * X2, [N*N2] + self.input_shape)]
+            cross_start = N + N2
+
+            @gpflow.decors.name_scope("apply_recurse_kern_X_X2")
+            def apply_recurse_kern(var_a_all, concat_outputs=True):
+                var_a_1 = var_a_all[:N]
+                var_a_2 = var_a_all[N:cross_start]
+                var_a_cross = var_a_all[cross_start:]
+                vz = [self.recurse_kern.Kdiag(var_a_1),
+                      self.recurse_kern.Kdiag(var_a_2),
+                      self.recurse_kern.K(var_a_cross, var_a_1, var_a_2)]
+                if concat_outputs:
+                    return tf.concat(vz, axis=0)
+                return vz
+        inputs = tf.concat(var_z_list, axis=0)
+        if self.data_format == "NHWC":
+            # Transpose NCHW -> NHWC
+            inputs = tf.transpose(inputs, [0, 2, 3, 1])
+
+        if len(self.block_sizes) > 0:
+            # Define almost all the network
+            inputs = self.headless_network(inputs, apply_recurse_kern)
+        # Last nonlinearity before final dense layer
+        var_z_list = apply_recurse_kern(inputs, concat_outputs=False)
+        # Averaging for the final dense layer
+        var_z_cross = tf.reshape(var_z_list[-1], [N, N2, -1])
+        var_z_cross_last = tf.reduce_mean(var_z_cross, axis=2)
+        result = self.var_bias + self.var_weight * var_z_cross_last
+        return result
+
+    @gpflow.decors.name_scope()
+    def Kdiag(self, X):
+        if X.dtype != self.input_type:
+            raise TypeError("Input dtype is wrong: {} is not {}"
+                            .format(X.dtype, self.input_type))
+        inputs = tf.reshape(tf.square(X), [-1] + self.input_shape)
+        if len(self.block_sizes) > 0:
+            inputs = self.headless_network(inputs, self.recurse_kern.Kdiag)
+        # Last dense layer
+        inputs = self.recurse_kern.Kdiag(inputs)
+
+        all_except_first = np.arange(1, len(inputs.shape))
+        var_z_last = tf.reduce_mean(inputs, axis=all_except_first)
+        result = self.var_bias + self.var_weight * var_z_last
+        return result
+
+    @abc.abstractmethod
+    def headless_network(self, inputs, apply_recurse_kern):
+        """
+        Apply the network that this kernel defines, except the last dense layer.
+        The last dense layer is different for K and Kdiag.
+        """
+        raise NotImplementedError
diff --git a/kernel/elementwise.py b/kernel/elementwise.py
new file mode 100644
index 0000000..3302f9c
--- /dev/null
+++ b/kernel/elementwise.py
@@ -0,0 +1,63 @@
+# Copyright 2019 Jiaxin Shi
+# Copyright 2017 https://github.com/rhaps0dy/convnets-as-gps
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+from __future__ import division
+
+import tensorflow as tf
+import numpy as np
+
+
+class ElementwiseKernel(object):
+    def K(self, cov, var1, var2=None):
+        raise NotImplementedError
+
+    def Kdiag(self, var):
+        raise NotImplementedError
+
+    def nlin(self, x):
+        """
+        The nonlinearity that this is computing the expected inner product of.
+        Used for testing.
+        """
+        raise NotImplementedError
+
+
+class ReLUKernel(ElementwiseKernel):
+    def __init__(self, name=None):
+        super(ReLUKernel, self).__init__()
+
+    def K(self, cov, var1, var2=None):
+        if var2 is None:
+            sqrt1 = sqrt2 = tf.sqrt(var1)
+        else:
+            sqrt1, sqrt2 = tf.sqrt(var1), tf.sqrt(var2)
+
+        norms_prod = sqrt1[:, None, ...] * sqrt2
+        norms_prod = tf.reshape(norms_prod, tf.shape(cov))
+
+        cos_theta = tf.clip_by_value(cov / norms_prod, -1., 1.)
+        theta = tf.acos(cos_theta)  # angle wrt the previous RKHS
+
+        sin_theta = tf.sqrt(1. - cos_theta**2)
+        J = sin_theta + (np.pi - theta) * cos_theta
+        div = 2*np.pi
+        return norms_prod / div * J
+
+    def Kdiag(self, var):
+        return var/2
+
+    def nlin(self, x):
+        return tf.nn.relu(x)
diff --git a/kernel/resnet.py b/kernel/resnet.py
new file mode 100644
index 0000000..b517d8c
--- /dev/null
+++ b/kernel/resnet.py
@@ -0,0 +1,125 @@
+# Copyright 2019 Jiaxin Shi
+# Copyright 2017 https://github.com/rhaps0dy/convnets-as-gps
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+from __future__ import division
+
+import gpflowSlim as gpflow
+import tensorflow as tf
+
+from .conv import ConvKernelBase
+
+
+@gpflow.decors.name_scope()
+def fixed_padding(inputs, kernel_size, data_format):
+    """Pads the input along the spatial dimensions independently of input size.
+
+    Args:
+        inputs: A tensor of size [batch, channels, height_in, width_in].
+        kernel_size: The kernel to be used in the conv2d or max_pool2d
+            operation. Should be a positive integer.
+        data_format: The input format, "NCHW" or "NHWC".
+
+    Returns:
+        A tensor with the same format as the input with the data either intact
+        (if kernel_size == 1) or padded (if kernel_size > 1).
+    """
+    pad_total = kernel_size - 1
+    if pad_total == 0:
+        return inputs
+    pad_beg = pad_total // 2
+    pad_end = pad_total - pad_beg
+    if data_format == "NCHW":
+        padded_inputs = tf.pad(inputs, [[0, 0], [0, 0],
+                                        [pad_beg, pad_end], [pad_beg, pad_end]])
+    else:
+        padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end],
+                                        [pad_beg, pad_end], [0, 0]])
+    return padded_inputs
+
+
+def conv2d_fixed_padding(inputs, var, kernel_size, strides,
+                         data_format, name='conv2d_fixed_padding'):
+    """Strided 2-D convolution with explicit padding."""
+    # The padding is consistent and is based only on `kernel_size`, not on the
+    # dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
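+    # There are no learned filters here: W below is a constant mean filter
+    # with a single output channel whose entries all equal var / fan_in, so
+    # the convolution propagates second moments (covariances) of activations
+    # rather than feature maps -- the behaviour of a conv layer with i.i.d.
+    # N(0, var / fan_in) weights in the infinite-channel limit (cf. the
+    # convnets-as-gps reference above).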
+    with tf.name_scope(name):
+        if strides > 1:
+            inputs = fixed_padding(inputs, kernel_size, data_format)
+        chan_idx = data_format.index("C")
+        try:
+            C = int(inputs.shape[chan_idx])
+        except TypeError:
+            C = tf.shape(inputs)[chan_idx]
+        fan_in = C * (kernel_size * kernel_size)
+        W = tf.fill([kernel_size, kernel_size, C, 1], var/fan_in)
+        if data_format == "NCHW":
+            strides_shape = [1, 1, strides, strides]
+        else:
+            strides_shape = [1, strides, strides, 1]
+        return tf.nn.conv2d(
+            input=inputs, filter=W, strides=strides_shape,
+            padding=('SAME' if strides == 1 else 'VALID'),
+            data_format=data_format)
+
+
+class ResnetKernel(ConvKernelBase):
+    "Kernel equivalent to Resnet V2 (tensorflow/models/official/resnet)"
+    @gpflow.decors.name_scope()
+    def headless_network(self, inputs, apply_recurse_kern):
+        """
+        Apply the network that this kernel defines, except the last dense layer.
+        The last dense layer is different for K and Kdiag.
+        """
+        # Copied from resnet_model.py
+        inputs = conv2d_fixed_padding(
+            inputs=inputs, var=self.var_weight,
+            kernel_size=self.kernel_size,
+            strides=self.conv_stride,
+            data_format=self.data_format,
+            name='initial_conv')
+
+        for i, num_blocks in enumerate(self.block_sizes):
+            with tf.name_scope("block_layer_{}".format(i+1)):
+                # Only the first block per block_layer uses
+                # projection_shortcut and strides
+                inputs = self.block_v2(inputs, True, self.block_strides[i],
+                                       apply_recurse_kern)
+                print("First layer of block {}:".format(i), inputs)
+                for j in range(1, num_blocks):
+                    inputs = self.block_v2(inputs, False, 1, apply_recurse_kern)
+                    print("Layer {} of block {}:".format(j, i), inputs)
+        # Dense layer
+        inputs = tf.reduce_mean(inputs, axis=(1, 2, 3))
+        return self.var_bias + self.var_weight * inputs
+
+    @gpflow.decors.name_scope()
+    def block_v2(self, inputs, projection_shortcut, strides, apply_recurse_kern):
+        shortcut = inputs
+        inputs = apply_recurse_kern(inputs)
+        if projection_shortcut:
+            # Need to project the inputs to a smaller space and also apply ReLU
+            del shortcut
+            shortcut = conv2d_fixed_padding(
+                inputs=inputs, var=self.var_weight, kernel_size=1,
+                strides=strides, data_format=self.data_format,
+                name='projection_shortcut')
+
+        inputs = conv2d_fixed_padding(
+            inputs=inputs, var=self.var_weight, kernel_size=3, strides=strides,
+            data_format=self.data_format)
+        inputs = apply_recurse_kern(inputs)
+        inputs = conv2d_fixed_padding(
+            inputs=inputs, var=self.var_weight, kernel_size=3, strides=1,
+            data_format=self.data_format)
+        return inputs + shortcut
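+
+
+# Usage sketch (mirrors the instantiation in classification.py; shapes and
+# hyperparameters are illustrative only):
+#
+#   from kernel.elementwise import ReLUKernel
+#   from kernel.resnet import ResnetKernel
+#
+#   kern = ResnetKernel(
+#       input_shape=[1, 28, 28],      # CHW, e.g. MNIST
+#       block_sizes=[2, 2, 2],        # residual blocks per stage
+#       block_strides=[1, 2, 2],
+#       kernel_size=3,
+#       recurse_kern=ReLUKernel(),
+#       var_weight=1.,
+#       var_bias=0.,
+#       conv_stride=1,
+#       data_format="NCHW",
+#       dtype=tf.float64)
+#   Kxx = kern.K(X)  # [N, N] prior covariance for an [N, 784] float64 matrix X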