diff --git a/freqopttest/data.py b/freqopttest/data.py
index 2aabb40..1926a0a 100644
--- a/freqopttest/data.py
+++ b/freqopttest/data.py
@@ -350,11 +350,13 @@ class SSGaussVarDiff(SampleSource):
     P = N(0, I), Q = N(0, diag((2, 1, 1, ...))). Only the variances of the first
     dimension differ."""

-    def __init__(self, d):
+    def __init__(self, d, var_d1=2.0):
         """
         d: dimension of the data
+        var_d1: variance of the first dimension. 2 by default.
         """
         self.d = d
+        self.var_d1 = var_d1

     def dim(self):
         return self.d
@@ -364,7 +366,8 @@ def sample(self, n, seed):
         np.random.seed(seed)

         d = self.d
-        std_y = np.diag(np.hstack((np.sqrt(2.0), np.ones(d-1) )))
+        var_d1 = self.var_d1
+        std_y = np.diag(np.hstack((np.sqrt(var_d1), np.ones(d-1) )))
         X = np.random.randn(n, d)
         Y = np.random.randn(n, d).dot(std_y)
         np.random.set_state(rstate)
diff --git a/freqopttest/glo.py b/freqopttest/glo.py
index 0d8bde1..98b9b73 100644
--- a/freqopttest/glo.py
+++ b/freqopttest/glo.py
@@ -66,7 +66,7 @@ def ex_save_result(ex, result, *relative_path):
     dir_path = os.path.dirname(fpath)
     create_dirs(dir_path)
     #
-    with open(fpath, 'w') as f:
+    with open(fpath, 'wb') as f:
         # expect result to be a dictionary
         pickle.dump(result, f)

@@ -85,7 +85,7 @@ def pickle_load(fpath):
     if not os.path.isfile(fpath):
         raise ValueError('%s does not exist' % fpath)

-    with open(fpath, 'r') as f:
+    with open(fpath, 'rb') as f:
         # expect a dictionary
         result = pickle.load(f)
     return result
diff --git a/freqopttest/tst.py b/freqopttest/tst.py
index 0c43348..43d24c5 100644
--- a/freqopttest/tst.py
+++ b/freqopttest/tst.py
@@ -279,33 +279,6 @@ def permutation_list_mmd2(X, Y, k, n_permute=400, seed=8273):
         for each permutation. We might be able to improve this if needed.
         """
         return QuadMMDTest.permutation_list_mmd2_gram(X, Y, k, n_permute, seed)
-        #rand_state = np.random.get_state()
-        #np.random.seed(seed)
-
-        #XY = np.vstack((X, Y))
-        #nxy = XY.shape[0]
-        #nx = X.shape[0]
-        #ny = Y.shape[0]
-        #list_mmd2 = np.zeros(n_permute)
-        #for r in range(n_permute):
-        #    ind = np.random.choice(nxy, nxy, replace=False)
-        #    # divide into new X, Y
-        #    Xr = XY[ind[:nx]]
-        #    Yr = XY[ind[nx:]]
-        #    mmd2r, var = QuadMMDTest.h1_mean_var(Xr, Yr, k, is_var_computed=False)
-        #    list_mmd2[r] = mmd2r
-
-        #np.random.set_state(rand_state)
-        #return list_mmd2
-
-    @staticmethod
-    def permutation_list_mmd2_rahul(X, Y, k, n_permute=400, seed=8273):
-        """ Permutation by maintaining inverse indices. This approach is due to
-        Rahul (Soumyajit De) briefly described in "Generative Models and Model
-        Criticism via Optimized Maximum Mean Discrepancy" """
-        pass
-
-

     @staticmethod
     def permutation_list_mmd2_gram(X, Y, k, n_permute=400, seed=8273):
@@ -1009,7 +982,7 @@ def ustat_h1_mean_variance(feature_matrix, return_variance=True,
         if return_variance:
             # compute the variance
             mu = np.mean(Z, axis=0) # length-J vector
-            variance = 4*np.mean(np.dot(Z, mu)**2) - 4*np.sum(mu**2)**2
+            variance = 4.0*np.mean(np.dot(Z, mu)**2) - 4.0*np.sum(mu**2)**2
             return mean_h1, variance
         else:
             return mean_h1
@@ -1143,6 +1116,71 @@ def flat_obj(x):

 # end of class GaussUMETest

+class MMDWitness(object):
+    """
+    Construct a callable object representing the (empirically estimated) MMD
+    witness function. The witness function g is defined as in Section 2.3 of
+
+    Gretton, Arthur, et al.
+    "A kernel two-sample test."
+    Journal of Machine Learning Research 13.Mar (2012): 723-773.
+
+    The witness function requires taking two expectations over the two sample
+    generating distributions. This is approximated by two empirical
+    expectations using the samples. The witness function
+    is a real function, which depends on the kernel k and two fixed samples.
+
+    The constructed object can be called as if it is a function: (J x d) numpy
+    array |-> length-J numpy array.
+    """
+
+    def __init__(self, k, X, Y):
+        """
+        :params k: a Kernel
+        :params X: a sample from p
+        :params Y: a sample from q
+        """
+        self.k = k
+        self.X = X
+        self.Y = Y
+
+    def __call__(self, V):
+        """
+        :params V: a numpy array of size J x d (data matrix)
+
+        :returns a one-dimensional length-J numpy array representing witness
+        evaluations at the J points.
+        """
+        J = V.shape[0]
+        k = self.k
+        X = self.X
+        Y = self.Y
+        n, d = X.shape
+
+        # When X, V contain many points, this can use a lot of memory.
+        # Process chunk by chunk.
+        block_rows = util.constrain(50000//d, 10, 5000)
+        sum_rows = []
+        for (f, t) in util.ChunkIterable(start=0, end=n, chunk_size=block_rows):
+            assert f