# -*- coding: utf-8 -*-
"""
Dimension reduction with optimal transport
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
#
# License: MIT License
from scipy import linalg
import autograd.numpy as np
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions
def dist(x1, x2):
""" Compute squared euclidean distance between samples (autograd)
"""
x1p2 = np.sum(np.square(x1), 1)
x2p2 = np.sum(np.square(x2), 1)
return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)
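# The helper above expands ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, which keeps
# the computation differentiable by autograd. A quick sanity check (an
# illustrative sketch, not part of the original module):
#
#     a = np.array([[0., 0.], [1., 0.]])
#     b = np.array([[0., 1.]])
#     dist(a, b)  # -> [[1.], [2.]]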
def sinkhorn(w1, w2, M, reg, k):
"""Sinkhorn algorithm with fixed number of iteration (autograd)
"""
K = np.exp(-M / reg)
ui = np.ones((M.shape[0],))
vi = np.ones((M.shape[1],))
for i in range(k):
vi = w2 / (np.dot(K.T, ui))
ui = w1 / (np.dot(K, vi))
G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
return G
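# Sinkhorn alternately rescales the rows and columns of the Gibbs kernel
# K = exp(-M / reg) so that the marginals of G approach w1 and w2. A quick
# sanity check (an illustrative sketch, not part of the original module):
#
#     w = np.ones(2) / 2
#     M = dist(np.eye(2), np.eye(2))
#     G = sinkhorn(w, w, M, 1., 100)
#     G.sum(axis=1)  # -> approximately [0.5, 0.5] (row marginals match w1)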
def split_classes(X, y):
"""split samples in X by classes in y
"""
lstsclass = np.unique(y)
return [X[y == i, :].astype(np.float32) for i in lstsclass]
def fda(X, y, p=2, reg=1e-16):
"""
Fisher Discriminant Analysis
Parameters
----------
X : numpy.ndarray (n,d)
Training samples
y : np.ndarray (n,)
labels for training samples
p : int, optional
        size of the dimensionality reduction
reg : float, optional
Regularization term >0 (ridge regularization)
Returns
-------
    P : (d x p) ndarray
        Optimal projection matrix for the given parameters
proj : fun
projection function including mean centering
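    Examples
    --------
    A minimal usage sketch on random toy data (illustrative only):

    >>> X = np.random.randn(100, 4)  # doctest: +SKIP
    >>> y = np.random.randint(0, 2, 100)  # doctest: +SKIP
    >>> Popt, proj = fda(X, y, p=1)  # doctest: +SKIP
    >>> proj(X).shape  # doctest: +SKIP
    (100, 1)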
"""
    mx = np.mean(X, 0)
X -= mx.reshape((1, -1))
# data split between classes
d = X.shape[1]
xc = split_classes(X, y)
nc = len(xc)
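    # FDA provides at most nc - 1 discriminant directions, so cap p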
p = min(nc - 1, p)
Cw = 0
for x in xc:
Cw += np.cov(x, rowvar=False)
Cw /= nc
mxc = np.zeros((d, nc))
for i in range(nc):
        mxc[:, i] = np.mean(xc[i], 0)
mx0 = np.mean(mxc, 1)
Cb = 0
for i in range(nc):
Cb += (mxc[:, i] - mx0).reshape((-1, 1)) * \
(mxc[:, i] - mx0).reshape((1, -1))
w, V = linalg.eig(Cb, Cw + reg * np.eye(d))
idx = np.argsort(w.real)
Popt = V[:, idx[-p:]]
def proj(X):
return (X - mx.reshape((1, -1))).dot(Popt)
return Popt, proj
def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
"""
Wasserstein Discriminant Analysis [11]_
The function solves the following optimization problem:
.. math::
        P = \\text{arg}\\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
where :
- :math:`P` is a linear projection operator in the Stiefel(p,d) manifold
    - :math:`W` is the entropic regularized Wasserstein distance
- :math:`X^i` are samples in the dataset corresponding to class i
Parameters
----------
X : numpy.ndarray (n,d)
Training samples
y : np.ndarray (n,)
labels for training samples
p : int, optional
        size of the dimensionality reduction
reg : float, optional
Regularization term >0 (entropic regularization)
    k : int, optional
        number of inner Sinkhorn iterations
    solver : None | str | object, optional
        None for steepest descent, 'TrustRegions' for the trust regions
        algorithm, otherwise a pymanopt solver instance
    maxiter : int, optional
        maximum number of iterations of the manifold solver
P0 : numpy.ndarray (d,p)
Initial starting point for projection
verbose : int, optional
Print information along iterations
Returns
-------
    P : (d x p) ndarray
        Optimal projection matrix for the given parameters
proj : fun
projection function including mean centering
References
----------
.. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
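    Examples
    --------
    A minimal usage sketch on random toy data (illustrative only; assumes
    pymanopt and autograd are installed):

    >>> X = np.random.randn(100, 10)  # doctest: +SKIP
    >>> y = np.random.randint(0, 3, 100)  # doctest: +SKIP
    >>> Pwda, projwda = wda(X, y, p=2, reg=1, k=10, maxiter=50)  # doctest: +SKIP
    >>> projwda(X).shape  # doctest: +SKIP
    (100, 2)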
""" # noqa
    mx = np.mean(X, 0)
X -= mx.reshape((1, -1))
# data split between classes
d = X.shape[1]
xc = split_classes(X, y)
    # compute uniform weights per class
wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc]
def cost(P):
# wda loss
loss_b = 0
loss_w = 0
for i, xi in enumerate(xc):
xi = np.dot(xi, P)
for j, xj in enumerate(xc[i:]):
xj = np.dot(xj, P)
M = dist(xi, xj)
G = sinkhorn(wc[i], wc[j + i], M, reg, k)
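                # j == 0 compares a class with itself: within-class term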
if j == 0:
loss_w += np.sum(G * M)
else:
loss_b += np.sum(G * M)
        # invert the ratio (within/between) because the solver minimizes
        return loss_w / loss_b
# declare manifold and problem
manifold = Stiefel(d, p)
problem = Problem(manifold=manifold, cost=cost)
# declare solver and solve
if solver is None:
solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
elif solver in ['tr', 'TrustRegions']:
solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)
Popt = solver.solve(problem, x=P0)
def proj(X):
return (X - mx.reshape((1, -1))).dot(Popt)
return Popt, proj
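# Illustrative smoke test (a minimal sketch, not part of the original module):
# it draws two Gaussian classes separated along the first feature and runs
# both fda and wda on them. Assumes pymanopt and autograd are installed.
if __name__ == '__main__':
    import numpy as onp  # plain numpy for data generation only
    rng = onp.random.RandomState(0)
    n = 50
    # class 1 is shifted by 3 along the first feature
    xs = onp.vstack((rng.randn(n, 5), rng.randn(n, 5) + 3 * onp.eye(5)[0]))
    ys = onp.concatenate((onp.zeros(n), onp.ones(n)))
    # both functions center X in place, so pass copies to keep xs untouched
    Pfda, projfda = fda(xs.copy(), ys, p=1)
    print('fda projection shape:', projfda(xs).shape)  # expected (100, 1)
    Pwda, projwda = wda(xs.copy(), ys, p=1, reg=1., k=10, maxiter=20)
    print('wda projection shape:', projwda(xs).shape)  # expected (100, 1)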