In [1]:
import torch
import numpy as np
from typing import List
import scipy
from scipy import linalg
torch.set_default_tensor_type(torch.DoubleTensor)

In [13]:
N = 1000000
d = 32
l = 5
H = (torch.randn(N,d) + torch.rand(N,d)**2)

H = H - torch.mean(H, dim = 0, keepdim = True)
H = H
W = torch.randn(d,l)
cov_H = torch.tensor(np.cov(H.detach().cpu().numpy(), rowvar = False))


In [14]:
# u = torch.nn.Parameter(u)
# optimizer = torch.optim.SGD([u], lr=0.001, momentum=0.9, weight_decay=1e-4)

In [15]:
def get_cov_output_projected(u, cov_H, W):
    
    u_normed = u / torch.norm(u)
    P = torch.eye(cov_H.shape[0]) - (u_normed@u_normed.T)
    #P = u_normed@u_normed.T
    return W.T@P@cov_H@P@W

def get_cov_output_total(H,W):
    Y_hat = H@W 
    return torch.tensor(np.cov(Y_hat.detach().cpu().numpy(), rowvar = False))

def get_loss_func(cov_output_projected):
    
    return torch.sum(torch.diag(cov_output_projected))

def get_projection_to_intersection_of_nullspaces(rowspace_projection_matrices: List[np.ndarray], input_dim: int):
    """
    Given a list of rowspace projection matrices P_R(w_1), ..., P_R(w_n),
    this function calculates the projection to the intersection of all nullspasces of the matrices w_1, ..., w_n.
    uses the intersection-projection formula of Ben-Israel 2013 http://benisrael.net/BEN-ISRAEL-NOV-30-13.pdf:
    N(w1)∩ N(w2) ∩ ... ∩ N(wn) = N(P_R(w1) + P_R(w2) + ... + P_R(wn))
    :param rowspace_projection_matrices: List[np.array], a list of rowspace projections
    :param dim: input dim
    """

    I = np.eye(input_dim)
    Q = np.sum(rowspace_projection_matrices, axis = 0)
    P = I - get_rowspace_projection(Q)

    return P

def get_rowspace_projection(W: np.ndarray) -> np.ndarray:
    """
    :param W: the matrix over its nullspace to project
    :return: the projection matrix over the rowspace
    """

    if np.allclose(W, 0):
        w_basis = np.zeros_like(W.T)
    else:
        w_basis = scipy.linalg.orth(W.T) # orthogonal basis

    P_W = w_basis.dot(w_basis.T) # orthogonal projection on W's rowspace

    return P_W

def BCA(H,W,n_components, eps = 1e-8, max_iters = 25000):
    
    P_nullspace = torch.eye(H.shape[1])
    results = []
    cov_out_total = get_cov_output_total(H,W)
    total_var_orig = get_loss_func(cov_out_total).detach().cpu().numpy().item()
    remaining_var = total_var_orig
    H_proj = H.clone()
    rowspace_projs = []
    
    for i in range(n_components):
        
        H_proj = H_proj@P_nullspace # remove previous component 
        #if i > 0: print("test: ", H_proj@u.double())
        #print("H proj", H_proj[:10,:])
        cov_H = torch.from_numpy(np.cov(H_proj.detach().cpu().numpy(), rowvar = False))
        #print("COV H proj", cov_H)
        print("-----------------------------")
        u = torch.randn(H_proj.shape[1], 1)
        u = torch.nn.Parameter(u)
        optimizer = torch.optim.SGD([u], lr=1e-3, momentum=0.9, weight_decay=1e-6)
        
        diff = 10
        j = 0
        loss_vals = [np.inf]
        
        while j < max_iters and diff > eps:
            optimizer.zero_grad()
            cov_out = get_cov_output_projected(u,cov_H,W)
            loss = get_loss_func(cov_out)
            loss.backward()
            optimizer.step()
            loss_vals.append(loss.detach().cpu().numpy().item())
            diff = np.abs(loss_vals[-1] - loss_vals[-2])
            if j % 500 == 0: print("j, loss, ", j, loss.detach().cpu().numpy().item(), diff)
            j += 1
        print("finished after {} iters".format(j))
        
        # calculate new nullspace projection to neutralzie component u
        
        u_normed = u / torch.norm(u)
        rowspace_projs.append((u_normed@u_normed.T).detach().cpu().numpy())
        #P_nullspace = torch.from_numpy(get_projection_to_intersection_of_nullspaces(rowspace_projs,cov_H.shape[0]))
        P_nullspace = torch.eye(H_proj.shape[1]).double() - u_normed@u_normed.T
        
        # calcualte explained variance
        cov_out_total = get_cov_output_total(H,W)
        total_var = get_loss_func(cov_out_total).detach().cpu().numpy().item()
        cov_out_projected = get_cov_output_projected(u,cov_H,W)
        total_var_projected = remaining_var - get_loss_func(cov_out_projected).detach().cpu().numpy().item()
        explained_var = total_var_projected / total_var
        
        remaining_var = remaining_var - total_var_projected
        
        u = u / u.norm()
        results.append({"vec": u.squeeze().detach().cpu().numpy(), "projected_var": total_var_projected,
                       "total_var": total_var, "explained_var": total_var_projected*100/total_var,
                       "cov_out":cov_out_projected})
    
    return results
        

In [16]:
bca = BCA(H,W,n_components=5)

-----------------------------
j, loss,  0 169.72049041549053 inf
j, loss,  500 122.4781739255136 0.00030679740329730976
j, loss,  1000 122.4527948697771 1.2034733032351141e-06
finished after 1463 iters
-----------------------------
j, loss,  0 122.2090252620225 inf
j, loss,  500 87.62471807844426 0.014922887538062923
j, loss,  1000 83.655355513856 0.0019122194108689428
j, loss,  1500 83.35161982039182 9.701867539035902e-05
j, loss,  2000 83.33607410721639 5.138820782235598e-06
j, loss,  2500 83.33524421391994 2.775559835299646e-07
j, loss,  3000 83.33519931961496 1.5047419310576515e-08
finished after 3072 iters
-----------------------------
j, loss,  0 82.82680324944411 inf
j, loss,  500 50.80210464974475 0.0008623053431833227
j, loss,  1000 50.673884473453874 5.7577423554278084e-05
j, loss,  1500 50.66024188371466 1.0917070099480952e-05
j, loss,  2000 50.65747284247999 2.324078295146137e-06
j, loss,  2500 50.65688128175265 4.975463525624946e-07
j, loss,  3000 50.65675463197924 1.06517

In [6]:
#bca[0]["vec"].T@bca[3]["vec"]

In [17]:
vecs = np.array([x["vec"] for x in bca])
for i,v in enumerate(vecs):
    for j, v2 in enumerate(vecs):
        if j <= i: continue
            
        print(i,j, v@v2.T)

0 1 7.627186184855361e-07
0 2 -1.1890638141193177e-06
0 3 -1.1041248385995628e-06
0 4 1.7216343892179076e-08
1 2 6.827629878830566e-07
1 3 4.641096717356019e-07
1 4 -6.809005553831682e-09
2 3 5.547453474294417e-07
2 4 -3.62159274097662e-09
3 4 2.282270400444375e-08


In [18]:

#for i in range(len(bca)):
    
    #print(bca[i]["explained_var"])
    
cov_out_total = get_cov_output_total(H,W)
total_var = get_loss_func(cov_out_total).detach().cpu().numpy().item()    
vars = [x["explained_var"] for x in bca]
print(vars)
print(sum(vars))

[28.415632737081403, 22.867610545707453, 19.103444285742405, 16.851800576081235, 12.761511669043047]
99.99999981365555


In [9]:
bca[-2]["cov_out"]

tensor([[ 1.1199e+01,  5.8626e+00,  5.9606e+00, -1.4767e+00, -3.9611e+00,
          2.7268e+00, -1.6588e+00, -3.9901e-01, -2.7961e+00,  3.2513e+00,
         -4.7000e+00, -3.8450e+00,  7.0242e-02,  2.3343e+00, -3.0070e-01,
          4.6083e+00],
        [ 5.8626e+00,  1.8849e+01,  6.1929e+00,  7.0691e-01,  3.3270e+00,
         -6.9513e+00, -2.6983e+00, -9.7823e-01,  3.0550e+00, -3.6983e-01,
         -8.7851e-01, -5.9634e-01, -2.4914e+00, -3.6286e+00, -6.8511e+00,
          3.4699e+00],
        [ 5.9606e+00,  6.1929e+00,  2.2912e+01,  1.8065e+00, -5.6866e+00,
         -4.6718e+00, -4.5252e+00,  6.4546e-03,  4.1944e-01, -2.1189e+00,
         -1.3973e+00, -3.3615e+00,  3.5891e+00, -7.3449e-01,  8.3607e+00,
         -1.3285e+00],
        [-1.4767e+00,  7.0691e-01,  1.8065e+00,  1.1346e+01,  5.2297e+00,
         -1.0216e+00,  4.5425e-01,  2.6010e+00,  7.1736e+00, -4.6572e+00,
          2.3475e+00,  4.8985e-01, -4.1403e+00,  4.2528e+00,  2.7935e+00,
         -3.8552e+00],
        [-3.9611e+00

In [10]:
bca[1]["explained_var"]

14.513608434725274

In [11]:
bca[0]["vec"].T@bca[2]["vec"]

3.199924110131036e-07

In [12]:
H.shape
u1, u2 = torch.randn(2,1), torch.randn(2,1)
u1 = u1 / torch.norm(u1)
u2 = u2 / torch.norm(u2)
P_u1 = P_u1 = torch.eye(2) - u1@u1.T
P_u2 = torch.eye(2) - u2@u2.T
P_u1= P_u1
P_u2 = P_u2
H@P_u1@P_u2

RuntimeError: size mismatch, m1: [1000000 x 32], m2: [2 x 2] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:41

In [None]:
P_u1 = torch.eye(2) - u1@u1.T
P_u2

In [None]:
torch.allclose