In [62]:
import torch
import numpy as np
from typing import List
import scipy
from scipy import linalg
from sklearn.decomposition import PCA

# Work in float64 throughout: np.cov always returns float64 and its results are mixed with torch tensors below.
# torch.set_default_tensor_type is deprecated (PyTorch >= 2.1); set_default_dtype is the supported replacement.
torch.set_default_dtype(torch.float64)

In [13]:
# Synthetic data: N samples of d skewed features (Gaussian plus squared-uniform noise),
# centred per feature, and a random linear readout W mapping them to l outputs.
SEED = 0
torch.manual_seed(SEED)  # make H and W reproducible across kernel restarts

N = 1000000
d = 32
l = 5
H = torch.randn(N, d) + torch.rand(N, d) ** 2
H = H - torch.mean(H, dim=0, keepdim=True)  # centre each feature (column)
W = torch.randn(d, l)
# (d, d) covariance of the features; np.cov with rowvar=False treats columns as variables.
cov_H = torch.from_numpy(np.cov(H.detach().cpu().numpy(), rowvar=False))


In [14]:
# u = torch.nn.Parameter(u)
# optimizer = torch.optim.SGD([u], lr=0.001, momentum=0.9, weight_decay=1e-4)

In [97]:
def get_cov_output_projected(u, cov_H, W):
    """Covariance of the outputs ``H @ W`` after projecting direction ``u`` out of ``H``.

    With û = u / ||u|| and P = I - û ûᵀ, returns Wᵀ P cov_H P W, i.e. Cov((H P) W) when
    cov_H = Cov(H). The result stays differentiable with respect to ``u``.

    :param u: direction to remove; either a (d, 1) column or a 1-D (d,) vector. Need not be unit norm.
    :param cov_H: (d, d) covariance of the inputs.
    :param W: (d, l) linear readout.
    :return: (l, l) tensor.
    """
    # Treat u as a column so that u_normed @ u_normed.T is the (d, d) outer product even for 1-D input.
    u_col = u.reshape(-1, 1)
    u_normed = u_col / torch.norm(u_col)
    # Build the identity on cov_H's dtype/device rather than the global default tensor type.
    eye = torch.eye(cov_H.shape[0], dtype=cov_H.dtype, device=cov_H.device)
    P = eye - u_normed @ u_normed.T
    return W.T @ P @ cov_H @ P @ W

def get_cov_output_total(H, W):
    """Sample covariance of the outputs ``Y_hat = H @ W`` (columns are the output variables).

    Computed with ``np.cov`` (N - 1 denominator) on a detached copy, so the returned float64
    tensor carries no autograd graph.

    :param H: (N, d) inputs, one sample per row.
    :param W: (d, l) linear readout.
    :return: (l, l) covariance tensor. Always 2-D: for a single-output W this is (1, 1), not the
             0-d array ``np.cov`` returns, so callers such as ``get_loss_func`` can take its diagonal.
    """
    Y_hat = H @ W
    cov = np.cov(Y_hat.detach().cpu().numpy(), rowvar=False)
    return torch.from_numpy(np.atleast_2d(cov))

def get_loss_func(cov_output_projected):
    """Total variance encoded by a covariance matrix: the sum of its diagonal entries (its trace).

    Differentiable, so it can be minimised with torch optimisers.
    """
    diagonal_entries = torch.diag(cov_output_projected)
    return diagonal_entries.sum()

def get_projection_to_intersection_of_nullspaces(rowspace_projection_matrices: List[np.ndarray], input_dim: int):
    """
    Given a list of rowspace projection matrices P_R(w_1), ..., P_R(w_n),
    calculate the projection onto the intersection of the nullspaces of w_1, ..., w_n.
    Uses the intersection-projection formula of Ben-Israel 2013 http://benisrael.net/BEN-ISRAEL-NOV-30-13.pdf:
    N(w1) ∩ N(w2) ∩ ... ∩ N(wn) = N(P_R(w1) + P_R(w2) + ... + P_R(wn))
    :param rowspace_projection_matrices: List[np.ndarray], a list of (input_dim, input_dim) rowspace projections
    :param input_dim: dimensionality of the input space
    :return: (input_dim, input_dim) projection onto the intersection of the nullspaces;
             the identity when the list is empty
    """

    I = np.eye(input_dim)
    if len(rowspace_projection_matrices) == 0:
        # The intersection over an empty family is the whole space, so the projection is the identity.
        return I

    # The nullspace of the summed rowspace projections equals the intersection of the nullspaces.
    Q = np.sum(rowspace_projection_matrices, axis=0)
    return I - get_rowspace_projection(Q)

def get_rowspace_projection(W: np.ndarray) -> np.ndarray:
    """
    Orthogonal projection matrix onto the rowspace of ``W``.

    :param W: matrix whose rows span the subspace to project onto
    :return: B Bᵀ, where the columns of B are an orthonormal basis of W's rowspace
             (for a 2-D W this is (W.shape[1], W.shape[1])); all zeros when W is numerically zero
    """
    if not np.allclose(W, 0):
        basis = scipy.linalg.orth(W.T)  # orthonormal basis of range(Wᵀ), i.e. W's rowspace
    else:
        # A (numerically) zero W has an empty rowspace: use an all-zero basis so the projection is zero.
        basis = np.zeros_like(W.T)
    return np.dot(basis, basis.T)

def get_first_pca(H):
    """Leading principal direction of ``H`` (rows are samples, columns are features).

    Returns it as a (n_features, 1) torch column; the dtype follows what sklearn's PCA produces.
    """
    top_direction = PCA(n_components=1).fit(H).components_  # shape (1, n_features)
    return torch.from_numpy(top_direction.T)

def _fit_removal_direction(cov_H, W, u_init, eps, max_iters, lr, momentum, weight_decay, log_every):
    """Optimise a single direction u with SGD so that projecting it out of the inputs removes
    as much output variance as possible.

    Minimises get_loss_func(get_cov_output_projected(u, cov_H, W)). Stops when the loss changes
    by at most ``eps`` between consecutive steps, or after ``max_iters`` steps.

    :return: the optimised direction, detached from autograd (same shape as ``u_init``; not normalised).
    """
    u = torch.nn.Parameter(u_init)
    optimizer = torch.optim.SGD([u], lr=lr, momentum=momentum, weight_decay=weight_decay)

    prev_loss = np.inf
    diff = np.inf  # guarantees at least one step (when max_iters > 0) regardless of eps
    j = 0
    while j < max_iters and diff > eps:
        optimizer.zero_grad()
        loss = get_loss_func(get_cov_output_projected(u, cov_H, W))
        loss.backward()
        optimizer.step()
        loss_val = loss.item()  # loss at u *before* this step
        diff = abs(loss_val - prev_loss)
        prev_loss = loss_val
        if log_every and j % log_every == 0:
            print("j, loss, ", j, loss_val, diff)
        j += 1
    print("finished after {} iters".format(j))
    return u.detach()


def BCA(H, W, n_components, eps=1e-8, max_iters=25000, init_pca=True,
        lr=1e-3, momentum=0.9, weight_decay=1e-10, log_every=500):
    """Greedily find input directions whose removal explains the most output variance.

    For each component, a direction u is optimised with SGD. The objective is the trace of the
    output covariance W^T P Cov(H_proj) P W, where P = I - û ûᵀ. u is then projected out of the
    (already projected) inputs H_proj before the next component is fitted.

    :param H: (N, d) inputs, one sample per row.
    :param W: (d, l) linear readout.
    :param n_components: number of directions to extract.
    :param eps: stop optimising a direction once the loss changes by at most this much per step.
    :param max_iters: maximum SGD steps per direction.
    :param init_pca: initialise each direction with the first PCA direction of H_proj (else random normal).
    :param lr, momentum, weight_decay: SGD hyper-parameters (defaults match the original hard-coded values).
    :param log_every: print the loss every this many steps; a falsy value disables per-step logging.
    :return: list of dicts, one per component, with keys:
        "vec": learned direction as a 1-D array (not normalised),
        "projected_var": output variance removed by this component,
        "total_var": trace of the output covariance of the original H,
        "explained_var": projected_var as a percentage of total_var,
        "cov_out": (l, l) output covariance left after removing this and all earlier components.
    """
    # H and W never change, so the total output variance is computed once.
    total_var = get_loss_func(get_cov_output_total(H, W)).item()
    remaining_var = total_var
    H_proj = H.detach().clone()
    results = []

    for _ in range(n_components):
        cov_H = torch.from_numpy(np.cov(H_proj.cpu().numpy(), rowvar=False))

        if init_pca:
            u_init = get_first_pca(H_proj.cpu().numpy())
        else:
            u_init = torch.randn(H_proj.shape[1], 1, dtype=cov_H.dtype)

        u = _fit_removal_direction(cov_H, W, u_init, eps, max_iters, lr, momentum, weight_decay, log_every)

        # Calculate the explained variance: output variance left once u (and every earlier direction) is removed.
        cov_out_projected = get_cov_output_projected(u, cov_H, W)
        var_after = get_loss_func(cov_out_projected).item()
        projected_var = remaining_var - var_after
        remaining_var = var_after

        results.append({"vec": u.squeeze().cpu().numpy(), "projected_var": projected_var,
                        "total_var": total_var, "explained_var": projected_var * 100 / total_var,
                        "cov_out": cov_out_projected})

        # Neutralize component u in the inputs for the next round. u is detached, so H_proj
        # carries no autograd graph and earlier (N, d) copies can be freed.
        u_normed = (u / torch.norm(u)).to(dtype=H_proj.dtype, device=H_proj.device)
        eye = torch.eye(H_proj.shape[1], dtype=H_proj.dtype, device=H_proj.device)
        P_nullspace = eye - u_normed @ u_normed.T
        H_proj = H_proj @ P_nullspace

    return results
        

In [101]:
# W has l = 5 output columns, so components beyond the 5th can explain (almost) no variance.
bca = BCA(H, W, init_pca=True, eps=1e-9, n_components=6)

j, loss,  0 169.22520978210827 inf
finished after 297 iters
j, loss,  0 120.34984077920129 inf
j, loss,  500 83.33519684677654 4.09292510994419e-09
finished after 537 iters
j, loss,  0 82.32590894700616 inf
j, loss,  500 50.65680543502892 2.142370014723838e-06
finished after 811 iters
j, loss,  0 48.93480232274766 inf
finished after 257 iters
j, loss,  0 21.740346904464143 inf
finished after 187 iters
j, loss,  0 7.609858549168947e-09 inf
finished after 2 iters


In [102]:
#bca[0]["vec"].T@bca[3]["vec"]

In [103]:
# Pairwise dot products (i < j) of the learned directions. The stored vectors are not explicitly
# normalised, so values near 0 indicate near-orthogonality but are not cosines.
vecs = np.array([x["vec"] for x in bca])
for i in range(len(vecs)):
    for j in range(i + 1, len(vecs)):
        print(i, j, vecs[i] @ vecs[j])

0 1 -6.798173381561412e-06
0 2 9.931919446204418e-06
0 3 8.62089792696974e-06
0 4 -1.228707438571952e-07
0 5 1.0913743242468854e-08
1 2 5.137323258463944e-06
1 3 3.2546245907205673e-06
1 4 -4.277656528994811e-08
1 5 3.508004775554241e-09
2 3 3.672781243579948e-06
2 4 -2.4219421745907965e-08
2 5 -1.012295913760397e-09
3 4 1.2162730955722623e-07
3 5 -6.004375330270761e-10
4 5 3.909578286109827e-10


In [104]:

# Percentage of the total output variance removed by each component, and their sum.
# (Named explained_vars to avoid shadowing the builtin `vars`.)
explained_vars = [x["explained_var"] for x in bca]
print(explained_vars)
print(sum(explained_vars))

[28.415633282144302, 22.86761097608209, 19.10344517169494, 16.851799391595954, 12.761511173989295, 4.479159908069864e-11]
99.99999999555136


In [105]:
# Learned direction of the last component (as stored by BCA, not explicitly normalised).
bca[-1]["vec"]

array([ 0.3132065 ,  0.14544315, -0.21388022, -0.07112003,  0.06993842,
        0.15482405,  0.24399994, -0.01227424,  0.00129624,  0.29998847,
       -0.03600201,  0.12333816,  0.05075786, -0.05182187,  0.12774499,
        0.20868128,  0.22070852,  0.30649608, -0.0399898 , -0.07595784,
       -0.22371177,  0.07208546,  0.01059946,  0.24092727,  0.19335357,
       -0.02625039,  0.0787781 ,  0.22569228,  0.25061547,  0.37341051,
        0.08615498, -0.03819199])

In [106]:
# Percentage of the total output variance removed by the third-to-last component.
bca[-3]["explained_var"]

16.851799391595954

In [107]:
# Spot-check one pair from the pairwise loop (components 0 and 2); `.T` is a no-op on 1-D arrays.
bca[0]["vec"] @ bca[2]["vec"]

9.931919446204418e-06