Skip to content

Commit

Permalink
add MMDWitness class for representing an MMD witness function
Browse files Browse the repository at this point in the history
  • Loading branch information
wittawatj committed May 13, 2018
1 parent f9b95ad commit b04d5da
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 39 deletions.
7 changes: 5 additions & 2 deletions freqopttest/data.py
Expand Up @@ -350,11 +350,13 @@ class SSGaussVarDiff(SampleSource):
P = N(0, I), Q = N(0, diag((2, 1, 1, ...))). Only the variances of the first
dimension differ."""

def __init__(self, d):
def __init__(self, d, var_d1=2.0):
"""
d: dimension of the data
var_d1: variance of the first dimension. 2 by default.
"""
self.d = d
self.var_d1 = var_d1

def dim(self):
return self.d
Expand All @@ -364,7 +366,8 @@ def sample(self, n, seed):
np.random.seed(seed)

d = self.d
std_y = np.diag(np.hstack((np.sqrt(2.0), np.ones(d-1) )))
var_d1 = self.var_d1
std_y = np.diag(np.hstack((np.sqrt(var_d1), np.ones(d-1) )))
X = np.random.randn(n, d)
Y = np.random.randn(n, d).dot(std_y)
np.random.set_state(rstate)
Expand Down
4 changes: 2 additions & 2 deletions freqopttest/glo.py
Expand Up @@ -66,7 +66,7 @@ def ex_save_result(ex, result, *relative_path):
dir_path = os.path.dirname(fpath)
create_dirs(dir_path)
#
with open(fpath, 'w') as f:
with open(fpath, 'wb') as f:
# expect result to be a dictionary
pickle.dump(result, f)

Expand All @@ -85,7 +85,7 @@ def pickle_load(fpath):
if not os.path.isfile(fpath):
raise ValueError('%s does not exist' % fpath)

with open(fpath, 'r') as f:
with open(fpath, 'rb') as f:
# expect a dictionary
result = pickle.load(f)
return result
Expand Down
94 changes: 66 additions & 28 deletions freqopttest/tst.py
Expand Up @@ -279,33 +279,6 @@ def permutation_list_mmd2(X, Y, k, n_permute=400, seed=8273):
for each permutation. We might be able to improve this if needed.
"""
return QuadMMDTest.permutation_list_mmd2_gram(X, Y, k, n_permute, seed)
#rand_state = np.random.get_state()
#np.random.seed(seed)

#XY = np.vstack((X, Y))
#nxy = XY.shape[0]
#nx = X.shape[0]
#ny = Y.shape[0]
#list_mmd2 = np.zeros(n_permute)
#for r in range(n_permute):
# ind = np.random.choice(nxy, nxy, replace=False)
# # divide into new X, Y
# Xr = XY[ind[:nx]]
# Yr = XY[ind[nx:]]
# mmd2r, var = QuadMMDTest.h1_mean_var(Xr, Yr, k, is_var_computed=False)
# list_mmd2[r] = mmd2r

#np.random.set_state(rand_state)
#return list_mmd2

@staticmethod
def permutation_list_mmd2_rahul(X, Y, k, n_permute=400, seed=8273):
""" Permutation by maintaining inverse indices. This approach is due to
Rahul (Soumyajit De) briefly described in "Generative Models and Model
Criticism via Optimized Maximum Mean Discrepancy" """
pass



@staticmethod
def permutation_list_mmd2_gram(X, Y, k, n_permute=400, seed=8273):
Expand Down Expand Up @@ -1009,7 +982,7 @@ def ustat_h1_mean_variance(feature_matrix, return_variance=True,
if return_variance:
# compute the variance
mu = np.mean(Z, axis=0) # length-J vector
variance = 4*np.mean(np.dot(Z, mu)**2) - 4*np.sum(mu**2)**2
variance = 4.0*np.mean(np.dot(Z, mu)**2) - 4.0*np.sum(mu**2)**2
return mean_h1, variance
else:
return mean_h1
Expand Down Expand Up @@ -1143,6 +1116,71 @@ def flat_obj(x):

# end of class GaussUMETest

class MMDWitness(object):
"""
Construct a callable object representing the (empirically estimated) MMD
witness function. The witness function g is defined as in Section 2.3 of
Gretton, Arthur, et al.
"A kernel two-sample test."
Journal of Machine Learning Research 13.Mar (2012): 723-773.
The witness function requires taking two expectations over the two sample
generating distributionls. This is approximated by two empirical
expectations using the samples. The witness function
is a real function, which depends on the kernel k and two fixed samples.
The constructed object can be called as if it is a function: (J x d) numpy
array |-> length-J numpy array.
"""

def __init__(self, k, X, Y):
"""
:params k: a Kernel
:params X: a sample from p
:params Y: a sample from q
"""
self.k = k
self.X = X
self.Y = Y

def __call__(self, V):
"""
:params V: a numpy array of size J x d (data matrix)
:returns a one-dimensional length-J numpy array representing witness
evaluations at the J points.
"""
J = V.shape[0]
k = self.k
X = self.X
Y = self.Y
n, d = X.shape

# When X, V contain many points, this can use a lot of memory.
# Process chunk by chunk.
block_rows = util.constrain(50000//d, 10, 5000)
sum_rows = []
for (f, t) in util.ChunkIterable(start=0, end=n, chunk_size=block_rows):
assert f<t
Xblk = X[f:t, :]
Yblk = Y[f:t, :]
# kernel evaluations
# b x J
Kx = k.eval(Xblk, V)
Ky = k.eval(Yblk, V)
# witness evaluations computed on only a subset of data
# ATTENTION: summing (instead of avf) may cause an overflow?
sum_rows.append((Kx-Ky).sum(axis=0))

# an array of length J
witness_evals = np.sum(np.vstack(sum_rows), axis=0)/float(n)
assert len(witness_evals) == J
return witness_evals

# end of class SteinWitness



class METest(TwoSampleTest):
"""
Expand Down
26 changes: 26 additions & 0 deletions freqopttest/util.py
Expand Up @@ -54,6 +54,32 @@ def __exit__(self, *args):

# end class NumpySeedContext

class ChunkIterable(object):
"""
Construct an Iterable such that each call to its iterator returns a tuple
of two indices (f, t) where f is the starting index, and t is the ending
index of a chunk. f and t are (chunk_size) apart except for the last tuple
which will always cover the rest.
"""
def __init__(self, start, end, chunk_size):
self.start = start
self.end = end
self.chunk_size = chunk_size

def __iter__(self):
s = self.start
e = self.end
c = self.chunk_size
# Probably not a good idea to use list. Waste memory.
L = list(range(s, e, c))
L.append(e)
return zip(L, L[1:])

# end ChunkIterable

def constrain(val, min_val, max_val):
return min(max_val, max(min_val, val))

def dist_matrix(X, Y):
"""
Construct a pairwise Euclidean distance matrix of size X.shape[0] x Y.shape[0]
Expand Down
102 changes: 95 additions & 7 deletions ipynb/quad_mmd_test.ipynb
Expand Up @@ -12,7 +12,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
Expand All @@ -38,11 +40,11 @@
"outputs": [],
"source": [
"# sample source \n",
"n = 300\n",
"n = 800\n",
"dim = 1\n",
"seed = 14\n",
"alpha = 0.01\n",
"# ss = data.SSGaussMeanDiff(dim, my=1)\n",
"ss = data.SSGaussMeanDiff(dim, my=1)\n",
"ss = data.SSGaussVarDiff(dim)\n",
"#ss = data.SSSameGauss(dim)\n",
"#ss = data.SSBlobs()\n",
Expand Down Expand Up @@ -82,13 +84,94 @@
"source": [
"start = time.time()\n",
"\n",
"perm_mmds1 = tst.QuadMMDTest.permutation_list_mmd2(xtr, ytr, k, n_permute=20)\n",
"perm_mmds1 = tst.QuadMMDTest.permutation_list_mmd2(xtr, ytr, k, n_permute=200)\n",
"\n",
"end = time.time()\n",
"print('permutations took: %.4f s'%(end-start))\n",
"print('perm_mmds1', perm_mmds1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def chi_square_weights_H0(k, X):\n",
" \"\"\"\n",
" Return a numpy array of the weights to be used as the weights in the\n",
" weighted sum of chi-squares for the null distribution of MMD^2.\n",
" - k: a Kernel\n",
" - X: n x d number array of n data points\n",
" \"\"\"\n",
" n = X.shape[0]\n",
" # Gram matrix\n",
" K = k.eval(X, X)\n",
" # centring matrix. Not the most efficient way.\n",
" H = np.eye(n) - np.ones((n, n))/float(n)\n",
" HKH = H.dot(K).dot(H)\n",
" #https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.linalg.eigvals.html\n",
" evals = np.linalg.eigvals(HKH)\n",
" evals = np.real(evals)\n",
" # sort in decreasing order \n",
" evals = -np.sort(-evals)\n",
" weights = evals/float(n)**2\n",
" return weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def simulate_null_spectral(weights, n_simulate=1000, seed=275):\n",
" \"\"\"\n",
" weights: chi-square weights (for the infinite weigted sum of chi squares) \n",
" Return the values of MMD^2 (NOT n*MMD^2) simulated from the null distribution by\n",
" the spectral method.\n",
" \"\"\"\n",
" # draw at most block_size values at a time\n",
" block_size = 400\n",
" D = len(weights)\n",
" mmds = np.zeros(n_simulate)\n",
" from_ind = 0\n",
"\n",
" with util.NumpySeedContext(seed=seed):\n",
" while from_ind < n_simulate:\n",
" to_draw = min(block_size, n_simulate-from_ind)\n",
" # draw chi^2 random variables. \n",
" chi2 = np.random.randn(D, to_draw)**2\n",
" # an array of length to_draw \n",
" sim_mmds = 2.0*weights.dot(chi2-1.0)\n",
" # store \n",
" end_ind = from_ind+to_draw\n",
" mmds[from_ind:end_ind] = sim_mmds\n",
" from_ind = end_ind\n",
" return mmds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"xytr = np.vstack((xtr, ytr))\n",
"chi2_weights = chi_square_weights_H0(k, xytr)\n",
"sim_mmds = simulate_null_spectral(chi2_weights, n_simulate=2000)\n",
"a = 0.6\n",
"\n",
"plt.figure(figsize=(8, 5))\n",
"plt.hist(perm_mmds1,20, color='blue', normed=True, label='Permutation', alpha=a)\n",
"plt.hist(sim_mmds, 20, color='red', normed=True, label='Spectral', alpha=a)\n",
"plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -144,7 +227,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
"scrolled": false
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -240,6 +323,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
Expand All @@ -251,7 +335,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# terms"
Expand All @@ -260,7 +346,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# perm_mmds = test_result['list_permuted_mmd2']\n",
Expand Down

0 comments on commit b04d5da

Please sign in to comment.