
Commit

add classification code
thjashin committed May 27, 2019
1 parent 4b1ec41 commit 8f4cf9c
Showing 9 changed files with 948 additions and 0 deletions.
147 changes: 147 additions & 0 deletions classification.py
@@ -0,0 +1,147 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
from __future__ import division
import os
from collections import namedtuple

import tensorflow as tf
import numpy as np

from kernel.resnet import ResnetKernel
from kernel.elementwise import ReLUKernel
from utils.log import setup_logger
from utils.data import load_mnist_realval, load_cifar10
from classification import svgp, gpnet, gpnet_nonconj, fbnn


FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("dataset", "mnist", "Dataset.")
tf.flags.DEFINE_string("method", "gpnet_nonconj", """Inference method.""")
tf.flags.DEFINE_string("net", "tangent", "Inference network.")
tf.flags.DEFINE_integer("batch_size", 128, """Total batch size.""")
tf.flags.DEFINE_float("learning_rate", 3e-4, """Learning rate.""")
tf.flags.DEFINE_integer("n_inducing", 100, """Number of inducing points.""")
tf.flags.DEFINE_string("measure", "train", "Measurement set.")
tf.flags.DEFINE_float("hyper_rate", 0, "Hyperparameter update rate.")
tf.flags.DEFINE_integer("block_size", 2, "number of blocks for each size.")
tf.flags.DEFINE_float("beta0", 0.01, """Initial beta value.""")
tf.flags.DEFINE_float("gamma", 0.1, """Beta schedule.""")
tf.flags.DEFINE_integer("n_iters", 10000, """Number of training iterations.""")
tf.flags.DEFINE_string("note", "", "Note for random experiments.")


def main():
flag_values = [
("method", FLAGS.method),
("net", FLAGS.net),
("inducing", FLAGS.n_inducing),
("beta0", FLAGS.beta0),
("gamma", FLAGS.gamma),
("niter", FLAGS.n_iters),
("bs", FLAGS.batch_size // 2),
("m", FLAGS.batch_size // 2),
("lr", FLAGS.learning_rate),
("measure", FLAGS.measure),
("hyper_rate", FLAGS.hyper_rate),
("block", FLAGS.block_size),
("note", FLAGS.note),
]
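    # Encode the flag settings into a single directory name, e.g.
    # "method@fbnn$net@tangent$inducing@100$...".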
flag_str = "$".join(["@".join([i[0], str(i[1])]) for i in flag_values])
result_path = os.path.join(
"results", "classification", FLAGS.dataset, flag_str)
logger = setup_logger("classification", __file__, result_path,
filename="log")

np.random.seed(1234)
tf.set_random_seed(1234)

    # Load the dataset (MNIST or CIFAR-10)
if FLAGS.dataset == "mnist":
train_x, train_y, valid_x, valid_y, test_x, test_y = load_mnist_realval(
dtype=np.float64)
train_x = np.vstack([train_x, valid_x])
train_y = np.vstack([train_y, valid_y])
input_shape = [1, 28, 28]
elif FLAGS.dataset == "cifar10":
train_x, train_y, test_x, test_y = load_cifar10(
dtype=np.float64)
input_shape = [3, 32, 32]
    else:
        raise NotImplementedError(
            "Unknown dataset: {}".format(FLAGS.dataset))

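    # Rescale pixel values (assumed to lie in [0, 1]) to [-1, 1].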
train_x = 2 * train_x - 1
test_x = 2 * test_x - 1

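    # Each step draws batch_size // 2 training points; the measurement set
    # below contributes another batch_size // 2 points (logged as "bs" and
    # "m" above).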
train = tf.data.Dataset.from_tensor_slices((train_x, train_y))
test = tf.data.Dataset.from_tensor_slices((test_x, test_y))
train = train.shuffle(buffer_size=1000).batch(
FLAGS.batch_size // 2).repeat()
test = test.batch(FLAGS.batch_size * 4)

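    # Measurement points come from the test inputs when --measure=test_x,
    # and from the training inputs otherwise.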
if FLAGS.measure == "test_x":
measure = tf.data.Dataset.from_tensor_slices(test_x)
else:
measure = tf.data.Dataset.from_tensor_slices(train_x)
measure = measure.shuffle(buffer_size=1000).batch(
FLAGS.batch_size // 2).repeat()
measure_iterator = measure.make_one_shot_iterator()
measure_batch = measure_iterator.get_next()

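    # Feedable iterator: the same next_batch op serves either the train or
    # test pipeline, depending on which string handle is fed at run time.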
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, train.output_types, train.output_shapes)
next_batch = iterator.get_next()

train_iterator = train.make_one_shot_iterator()
test_iterator = test.make_initializable_iterator()

sess = tf.Session()

train_handle = sess.run(train_iterator.string_handle())
test_handle = sess.run(test_iterator.string_handle())

Data = namedtuple("Data", [
"next_batch",
"measure_batch",
"handle",
"train_handle",
"test_handle",
"test_iterator",
"train_x",
"train_y"])
data = Data(next_batch, measure_batch, handle, train_handle, test_handle,
test_iterator, train_x, train_y)

block_sizes = [FLAGS.block_size] * 3
block_strides = [1, 2, 2]
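    # GP prior with a ResNet-structured kernel (ReLU recursion). Building it
    # under the "prior" scope lets its variables be initialized separately
    # below.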
with tf.variable_scope("prior"):
resnet_kern = ResnetKernel(
input_shape=input_shape,
block_sizes=block_sizes,
block_strides=block_strides,
kernel_size=3,
recurse_kern=ReLUKernel(),
var_weight=1.,
var_bias=0.,
conv_stride=1,
data_format="NCHW",
dtype=tf.float64,
)

sess.run(tf.variables_initializer(tf.trainable_variables("prior")))

    # Dispatch to the selected inference method
    if FLAGS.method == "svgp":
        svgp(logger, sess, data, resnet_kern)
    elif FLAGS.method == "gpnet":
        gpnet(logger, sess, data, resnet_kern, dtype=tf.float64)
    elif FLAGS.method == "gpnet_nonconj":
        gpnet_nonconj(logger, sess, data, resnet_kern, dtype=tf.float64)
    elif FLAGS.method == "fbnn":
        fbnn(logger, sess, data, resnet_kern, dtype=tf.float64)
    else:
        raise NotImplementedError(
            "Unknown method: {}".format(FLAGS.method))


if __name__ == "__main__":
main()
6 changes: 6 additions & 0 deletions classification/__init__.py
@@ -0,0 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from .gpnet import gpnet, gpnet_nonconj
from .svgp import svgp
from .fbnn import fbnn
100 changes: 100 additions & 0 deletions classification/fbnn.py
@@ -0,0 +1,100 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
from __future__ import division

import tensorflow as tf
import numpy as np
import zhusuan as zs
import gpflowSlim as gpflow

from bnn.multi_output.resnet import build_resnet
from utils.mvn import multivariate_normal_kl


FLAGS = tf.flags.FLAGS


def fbnn(logger, sess, data, kernel, dtype=tf.float64):
train_x = data.train_x
N, x_dim = train_x.shape
_, n_cls = data.train_y.shape

bnn = build_resnet(
"bnn",
n_cls,
kernel.input_shape,
kernel.block_sizes,
kernel.block_strides,
data_format="NCHW",
dtype=dtype,
net=FLAGS.net)

x, y = data.next_batch
# x_star: [bs_star, x_dim], x: [bs, x_dim]
x_star = data.measure_batch
# xx: [bs + bs_star, x_dim]
xx = tf.concat([x, x_star], axis=0)

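    # Variational posterior over function values at both the training and
    # measurement inputs; the functional KL below is taken over this
    # combined set.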
qff = bnn(xx)
# qf_mean: [n_cls, bs], qf_var: [n_cls, bs], f_pred: [n_cls, bs]
qf_mean, qf_var = bnn(x, full_cov=False)
f_pred = qf_mean + tf.sqrt(qf_var) * tf.random_normal(tf.shape(qf_mean),
dtype=dtype)
# y_pred: [bs]
y_pred = tf.argmax(qf_mean, axis=0, output_type=tf.int32)
# y_target: [bs]
y_target = tf.argmax(y, axis=1, output_type=tf.int32)
# acc: []
acc = tf.reduce_mean(tf.to_float(tf.equal(y_pred, y_target)))

# K_prior: [bs + bs_star, bs + bs_star]
K_prior = kernel.K(xx)
K_prior_tril = tf.cholesky(
K_prior + tf.eye(tf.shape(xx)[0], dtype=dtype) * gpflow.settings.jitter)
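    # Zero-mean GP prior at xx: one independent GP per class, all sharing
    # the same kernel Cholesky factor.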
pff = zs.distributions.MultivariateNormalCholesky(
tf.zeros([n_cls, tf.shape(xx)[0]], dtype=dtype),
tf.tile(K_prior_tril[None, ...], [n_cls, 1, 1]))

# likelihood term
f_term = -tf.nn.softmax_cross_entropy_with_logits(
labels=y,
logits=tf.matrix_transpose(f_pred))
f_term = tf.reduce_sum(f_term)

# kl term
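    # multivariate_normal_kl presumably evaluates the standard closed form
    # KL(N(m1, S1) || N(m2, S2)) = 0.5 * [tr(S2^-1 S1)
    #   + (m2 - m1)^T S2^-1 (m2 - m1) - k + ln det(S2) - ln det(S1)]
    # for each of the n_cls outputs; the sum is over classes.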
kl_term = tf.reduce_sum(multivariate_normal_kl(qff, pff))

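    # Objective: likelihood term minus functional KL; the optimizer below
    # minimizes its negative.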
lower_bound = f_term - kl_term

fbnn_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
bnn_var = tf.trainable_variables(scope="bnn")
infer_fbnn = fbnn_opt.minimize(-lower_bound, var_list=bnn_var)
print_freq = 1
test_freq = 100
sess.run(tf.variables_initializer(var_list=bnn_var + fbnn_opt.variables()))
train_stats = []
for t in range(1, FLAGS.n_iters + 1):
_, train_ll, train_acc = sess.run(
[infer_fbnn, lower_bound, acc],
feed_dict={data.handle: data.train_handle})
train_stats.append((train_ll, train_acc))

if t % print_freq == 0:
train_lls, train_accs = list(zip(*train_stats))
logger.info("Iter {}, lower bound = {:.4f}, train acc = {:.4f}"
.format(t, np.mean(train_lls), np.mean(train_accs)))
train_stats = []

if t % test_freq == 0:
sess.run(data.test_iterator.initializer)
test_stats = []
while True:
try:
test_stats.append(
sess.run(acc,
feed_dict={data.handle: data.test_handle}))
except tf.errors.OutOfRangeError:
break
logger.info(">> Test acc = {:.4f}".format(np.mean(test_stats)))
