In [1]:
import torch
import numpy as np
from typing import List
import scipy
from scipy import linalg
torch.set_default_tensor_type(torch.DoubleTensor)

In [22]:
N = 1000000
d = 32
l = 5
H = (torch.randn(N,d) + torch.rand(N,d)**2)

H = H - torch.mean(H, dim = 0, keepdim = True)
H = H
W = torch.randn(d,l)
cov_H = torch.tensor(np.cov(H.detach().cpu().numpy(), rowvar = False))


In [23]:
# u = torch.nn.Parameter(u)
# optimizer = torch.optim.SGD([u], lr=0.001, momentum=0.9, weight_decay=1e-4)

In [24]:
def get_cov_output_projected(u, cov_H, W):
    
    u_normed = u / torch.norm(u)
    P = torch.eye(cov_H.shape[0]) - (u_normed@u_normed.T)
    #P = u_normed@u_normed.T
    return W.T@P@cov_H@P@W

def get_cov_output_total(H,W):
    Y_hat = H@W 
    return torch.tensor(np.cov(Y_hat.detach().cpu().numpy(), rowvar = False))

def get_loss_func(cov_output_projected):
    
    return torch.sum(torch.diag(cov_output_projected))

def get_projection_to_intersection_of_nullspaces(rowspace_projection_matrices: List[np.ndarray], input_dim: int):
    """
    Given a list of rowspace projection matrices P_R(w_1), ..., P_R(w_n),
    this function calculates the projection to the intersection of all nullspasces of the matrices w_1, ..., w_n.
    uses the intersection-projection formula of Ben-Israel 2013 http://benisrael.net/BEN-ISRAEL-NOV-30-13.pdf:
    N(w1)∩ N(w2) ∩ ... ∩ N(wn) = N(P_R(w1) + P_R(w2) + ... + P_R(wn))
    :param rowspace_projection_matrices: List[np.array], a list of rowspace projections
    :param dim: input dim
    """

    I = np.eye(input_dim)
    Q = np.sum(rowspace_projection_matrices, axis = 0)
    P = I - get_rowspace_projection(Q)

    return P

def get_rowspace_projection(W: np.ndarray) -> np.ndarray:
    """
    :param W: the matrix over its nullspace to project
    :return: the projection matrix over the rowspace
    """

    if np.allclose(W, 0):
        w_basis = np.zeros_like(W.T)
    else:
        w_basis = scipy.linalg.orth(W.T) # orthogonal basis

    P_W = w_basis.dot(w_basis.T) # orthogonal projection on W's rowspace

    return P_W

def BCA(H,W,n_components, eps = 1e-7, max_iters = 25000):
    
    P_nullspace = torch.eye(H.shape[1])
    results = []
    cov_out_total = get_cov_output_total(H,W)
    total_var = get_loss_func(cov_out_total).detach().cpu().numpy().item()
    H_proj = H.clone()
    rowspace_projs = []
    
    for i in range(n_components):
        
        H_proj = H_proj@P_nullspace # remove previous component 
        #if i > 0: print("test: ", H_proj@u.double())
        #print("H proj", H_proj[:10,:])
        cov_H = torch.from_numpy(np.cov(H_proj.detach().cpu().numpy(), rowvar = False))
        #print("COV H proj", cov_H)
        print("-----------------------------")
        u = torch.randn(H_proj.shape[1], 1)
        u = torch.nn.Parameter(u)
        optimizer = torch.optim.SGD([u], lr=1e-3, momentum=0.9, weight_decay=1e-6)
        
        diff = 10
        j = 0
        loss_vals = [np.inf]
        
        while j < max_iters and diff > eps:
            optimizer.zero_grad()
            cov_out = get_cov_output_projected(u,cov_H,W)
            loss = get_loss_func(cov_out)
            loss.backward()
            optimizer.step()
            loss_vals.append(loss.detach().cpu().numpy().item())
            diff = np.abs(loss_vals[-1] - loss_vals[-2])
            if j % 500 == 0: print("j, loss, ", j, loss.detach().cpu().numpy().item(), diff)
            j += 1
        print("finished after {} iters".format(j))
        
        # calculate new nullspace projection to neutralzie component u
        
        u_normed = u / torch.norm(u)
        rowspace_projs.append((u_normed@u_normed.T).detach().cpu().numpy())
        #P_nullspace = torch.from_numpy(get_projection_to_intersection_of_nullspaces(rowspace_projs,cov_H.shape[0]))
        P_nullspace = torch.eye(H_proj.shape[1]).double() - u_normed@u_normed.T
        
        # calcualte explained variance
        cov_out_projected = get_cov_output_projected(u,cov_H,W)
        total_var_projected = total_var-get_loss_func(cov_out_projected).detach().cpu().numpy().item()
        explained_var = total_var_projected / total_var
        
        u = u / u.norm()
        results.append({"vec": u.squeeze().detach().cpu().numpy(), "projected_var": total_var_projected,
                       "total_var": total_var, "explained_var": total_var_projected*100/total_var,
                       "cov_out":cov_out_projected})
    
    return results
        

In [28]:
bca = BCA(H,W,n_components=5)

-----------------------------
j, loss,  0 135.26160273520554 inf
j, loss,  500 100.55800708361618 0.0035053699756133483
j, loss,  1000 99.9186858463355 0.0005054382567237781
j, loss,  1500 99.7617881388654 0.00017777623224901618
j, loss,  2000 99.7078965576554 5.859931863483325e-05
j, loss,  2500 99.69045049489841 1.862443282618642e-05
j, loss,  3000 99.6849391273175 5.847988333584908e-06
j, loss,  3500 99.68321197067934 1.8290295429324033e-06
j, loss,  4000 99.68267212884751 5.713088597758542e-07
j, loss,  4500 99.68250354281327 1.7837221832905925e-07
finished after 4750 iters
-----------------------------
j, loss,  0 99.1276355328844 inf
j, loss,  500 63.05948957976347 0.0027859741574047803
j, loss,  1000 62.67309401257155 7.575458378994426e-05
j, loss,  1500 62.663198902056045 1.8275550885960001e-06
finished after 1891 iters
-----------------------------
j, loss,  0 61.444175393820906 inf
j, loss,  500 39.99108873018874 0.030079992229133268
j, loss,  1000 32.36511302457802 0.0003949

In [29]:
#bca[0]["vec"].T@bca[3]["vec"]

In [30]:
vecs = np.array([x["vec"] for x in bca])
for i,v in enumerate(vecs):
    for j, v2 in enumerate(vecs):
        if j <= i: continue
            
        print(i,j, v@v2.T)

0 1 -1.265984961192762e-06
0 2 -1.407897433206018e-06
0 3 -1.1746536428269838e-06
0 4 -1.6790201201599686e-08
1 2 -5.058391205903234e-10
1 3 -4.4132394230039784e-07
1 4 -8.326525940960394e-09
2 3 -1.1876170933067254e-06
2 4 -2.227892656470054e-08
3 4 9.776601021804776e-09


In [31]:

#for i in range(len(bca)):
    
    #print(bca[i]["explained_var"])
    
    
print(sum([x["explained_var"] for x in bca]))

350.39296051290836


In [34]:
bca[-2]["cov_out"]

tensor([[ 4.0669, -3.6262, -0.9620,  1.9986,  4.2501],
        [-3.6262,  3.2333,  0.8578, -1.7820, -3.7896],
        [-0.9620,  0.8578,  0.2276, -0.4728, -1.0054],
        [ 1.9986, -1.7820, -0.4728,  0.9822,  2.0886],
        [ 4.2501, -3.7896, -1.0054,  2.0886,  4.4416]], grad_fn=<MmBackward>)

In [10]:
bca[1]["explained_var"]

81.3286209252445

In [16]:
bca[0]["vec"].T@bca[2]["vec"]

8.559765125870644e-09

In [12]:
H.shape
u1, u2 = torch.randn(2,1), torch.randn(2,1)
u1 = u1 / torch.norm(u1)
u2 = u2 / torch.norm(u2)
P_u1 = P_u1 = torch.eye(2) - u1@u1.T
P_u2 = torch.eye(2) - u2@u2.T
P_u1= P_u1
P_u2 = P_u2
H@P_u1@P_u2

RuntimeError: size mismatch, m1: [1000000 x 32], m2: [2 x 2] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:41

In [None]:
P_u1 = torch.eye(2) - u1@u1.T
P_u2

In [None]:
torch.allclose