In [1]:
import torch
import numpy as np
from typing import List
import scipy
from scipy import linalg
from sklearn.decomposition import PCA
import pickle
torch.set_default_tensor_type(torch.DoubleTensor)

In [2]:
with open("encodings.bert-base.250k.pickle", "rb") as f:
    data = pickle.load(f)

In [3]:
H = np.array([d["vec_batch_norm"] for d in data])
with open("bert_embeddings.pickle", "rb") as f:
    W = pickle.load(f)
    
H,W = torch.from_numpy(H).double(), torch.from_numpy(W).double()
H = H[:100000]
W = W.T
H.shape, W.shape

(torch.Size([100000, 768]), torch.Size([768, 30522]))

In [5]:
# N = 100000
# d = 128
# l = 25
# H = (torch.randn(N,d) + torch.rand(N,d)**2)

# H = H - torch.mean(H, dim = 0, keepdim = True)
# H = H
# W = torch.randn(d,l)
cov_H = torch.tensor(np.cov(H.detach().cpu().numpy(), rowvar = False))
print(cov_H[0]@W)

tensor([0.5134, 0.5406, 0.4952,  ..., 0.4864, 0.3009, 0.5797])


In [6]:
del data

In [None]:
# import time
# start = time.time()
# Q = cov_H[:1000]@W
# print(Q)
# print(cov_H.shape, W.shape)
# print(time.time() - start)

In [None]:
# u = torch.nn.Parameter(u)
# optimizer = torch.optim.SGD([u], lr=0.001, momentum=0.9, weight_decay=1e-4)

In [38]:
def get_cov_output_projected(u, cov_H, W):
    
    u_normed = u / torch.norm(u)
    P = torch.eye(cov_H.shape[0]) - (u_normed@u_normed.T)
    #P = u_normed@u_normed.T
    print(P.shape, W.shape, cov_H.shape, u_normed.shape)
    first = P@W
    print("done first")
    second = cov_H@first
    print("done second")
    third = P@second
    print("done third")
    fourth = W.T@third
    print("done fourth")
    return fourth
    #return W.T@P@cov_H@P@W

def get_cov_output_total(H,W):
    with torch.no_grad():
        Y_hat = H[:5000]@W 
        return torch.sum(Y_hat*Y_hat)
    #return torch.tensor(np.cov(Y_hat.detach().cpu().numpy(), rowvar = False))

def get_loss_func(cov_output_projected):
    
    loss =  torch.sum(torch.diag(cov_output_projected))
    print("done loss calculation")
    return loss

def get_loss_func2(u, cov_H, W):

    u_normed = u / torch.norm(u)
    P = torch.eye(cov_H.shape[0]) - (u_normed@u_normed.T)
    #P = u_normed@u_normed.T
    first = P@W
    second = cov_H@first
    third = P@second
    fourth = torch.sum(W*third)
    return fourth

def get_projection_to_intersection_of_nullspaces(rowspace_projection_matrices: List[np.ndarray], input_dim: int):
    """
    Given a list of rowspace projection matrices P_R(w_1), ..., P_R(w_n),
    this function calculates the projection to the intersection of all nullspasces of the matrices w_1, ..., w_n.
    uses the intersection-projection formula of Ben-Israel 2013 http://benisrael.net/BEN-ISRAEL-NOV-30-13.pdf:
    N(w1)∩ N(w2) ∩ ... ∩ N(wn) = N(P_R(w1) + P_R(w2) + ... + P_R(wn))
    :param rowspace_projection_matrices: List[np.array], a list of rowspace projections
    :param dim: input dim
    """

    I = np.eye(input_dim)
    Q = np.sum(rowspace_projection_matrices, axis = 0)
    P = I - get_rowspace_projection(Q)

    return P

def get_rowspace_projection(W: np.ndarray) -> np.ndarray:
    """
    :param W: the matrix over its nullspace to project
    :return: the projection matrix over the rowspace
    """

    if np.allclose(W, 0):
        w_basis = np.zeros_like(W.T)
    else:
        w_basis = scipy.linalg.orth(W.T) # orthogonal basis

    P_W = w_basis.dot(w_basis.T) # orthogonal projection on W's rowspace

    return P_W

def get_first_pca(H):
    pca = PCA(n_components = 1)
    pca.fit(H)
    return torch.from_numpy(pca.components_.T)

def BCA(H,W,n_components, eps = 1e-8, max_iters = 1000, init_pca = True):
    
    P_nullspace = torch.eye(H.shape[1])
    results = []
    #cov_out_total = get_cov_output_total(H,W)
    total_var_orig = get_cov_output_total(H,W) #get_loss_func(cov_out_total).detach().cpu().numpy().item()
    remaining_var = total_var_orig
    print("Total var original: ", remaining_var)
    H_proj = H.clone()
    rowspace_projs = []
    
    for i in range(n_components):
        
        H_proj = H@P_nullspace # remove previous component 
        cov_H = torch.from_numpy(np.cov(H_proj.detach().cpu().numpy(), rowvar = False))
        
        if init_pca:
            u = get_first_pca(H_proj.detach().cpu().numpy())
        else:
            u = torch.randn(H_proj.shape[1], 1)
        u = torch.nn.Parameter(u)
        optimizer = torch.optim.SGD([u], lr=1e-4, momentum=0.8)
        #optimizer = torch.optim.Adam([u])
        
        diff = 10
        j = 0
        loss_vals = [np.inf]
        patience = 4
        patience_counter = 0
        
        while j < max_iters and patience_counter < patience:
            optimizer.zero_grad()
            #cov_out = get_cov_output_projected(u,cov_H,W)
            #loss = get_loss_func(cov_out)
            loss = get_loss_func2(u, cov_H, W)
            loss.backward()
            optimizer.step()
            loss_vals.append(loss.detach().cpu().numpy().item())
            diff = np.abs(loss_vals[-1] - loss_vals[-2])
            
            if diff > eps:
                patience_counter = 0
            else:
                patience_counter += 1
                
            if j % 25 == 0: print("j, loss, ", j, loss.detach().cpu().numpy().item(), diff)
            j += 1
        print("finished after {} iters".format(j))
        
        # calculate new nullspace projection to neutralzie component u
        
        u_normed = u / torch.norm(u)
        rowspace_projs.append((u_normed@u_normed.T).detach().cpu().numpy())
        P_nullspace = torch.from_numpy(get_projection_to_intersection_of_nullspaces(rowspace_projs,cov_H.shape[0]))
        #P_nullspace = torch.eye(H_proj.shape[1]).double() - u_normed@u_normed.T
        
        # calcualte explained variance
        #cov_out_total = get_cov_output_total(H,W)
        total_var = get_cov_output_total(H,W)
        #cov_out_projected = get_cov_output_projected(u,cov_H,W)
        #total_var_projected = get_loss_func(cov_out_projected).detach().cpu().numpy().item()
        total_var_projected = remaining_var - get_loss_func2(u,cov_H,W)
        explained_var = total_var_projected / total_var
        remaining_var = remaining_var - total_var_projected
        
        #u = u / u.norm()
        results.append({"vec": u.squeeze().detach().cpu().numpy(), "projected_var": total_var_projected,
                       "total_var": total_var, "explained_var": total_var_projected*100/total_var,
                       "cov_out":cov_out_projected})
    
    return results
        

In [40]:
bca = BCA(H,W,n_components=1, eps = 35, init_pca = True)

Total var original:  tensor(3.4003e+09)
j, loss,  0 155768.90582658348 inf
j, loss,  2 148063.25288920783 13594.808557405311
j, loss,  4 142304.65481574967 5933.130929292238
j, loss,  6 138352.90200163357 774.957252429449
j, loss,  8 137305.78105158376 717.4958084040263
j, loss,  10 135306.61000177864 985.1884706384444
j, loss,  12 133918.25313571922 604.7058947051992
j, loss,  14 132875.15765970555 517.2771117776283
j, loss,  16 131854.25553355517 500.9627363670443
j, loss,  18 130973.92237513408 418.5667270269478
j, loss,  20 130232.62488413988 357.93383542636
j, loss,  22 129576.05019058849 318.69090012906236
j, loss,  24 128998.81754779919 278.5579849306232
j, loss,  26 128496.3249894953 242.80060436206986
j, loss,  28 128054.33897504794 214.19489677906677
j, loss,  30 127662.79414672998 189.92334473092342
j, loss,  32 127314.43890327595 169.23853685667564
j, loss,  34 127002.05235633712 152.1201874371909
j, loss,  36 126719.4177254966 137.9319589875522
j, loss,  38 126461.55297522

NameError: name 'total_var_projected' is not defined

In [None]:
#bca[0]["vec"].T@bca[3]["vec"]

In [None]:
vecs = np.array([x["vec"] for x in bca])
for i,v in enumerate(vecs):
    for j, v2 in enumerate(vecs):
        if j <= i: continue
            
        print(i,j, v@v2.T)

In [None]:

#for i in range(len(bca)):
    
    #print(bca[i]["explained_var"])
    
cov_out_total = get_cov_output_total(H,W)
total_var = get_loss_func(cov_out_total).detach().cpu().numpy().item()    
vars = [x["explained_var"] for x in bca]
print(vars)
print(sum(vars))

In [None]:
bca[-1]["cov_out"]

In [None]:
bca[-3]["explained_var"]

In [None]:
u,v = np.random.randn(768), np.random.randn(768).T
u,v = u / np.linalg.norm(u), v/np.linalg.norm(v)
u@v

In [None]:
a = [1,2,3,4,5,6,7,8,9,10,11,12]

In [None]:
a[-12:-2]

In [None]:
A = np.random.rand(768,32000)
H = np.random.rand(10000,768)
H*A

In [None]:
W = torch.randn(768,32000)
H = torch.randn(10000, 768)
Y_hat = H[:10000]@W 
torch.mean(Y_hat*Y_hat)

In [None]:
print(Y_hat.shape)