In [12]:
import torch
import numpy as np
from typing import List
import scipy
from scipy import linalg
from sklearn.decomposition import PCA
import pickle
import tqdm
torch.set_default_tensor_type(torch.DoubleTensor)


In [42]:
with open("encodings.bert-base.250k.pickle", "rb") as f:
    data = pickle.load(f)
    
sents = np.array([d["sent"] for d in data])

In [47]:
sents = np.array(sents)
masked_tokens = np.array([d["gold"] for d in data])
top_words = np.array([d["top_words"][0] for d in data])

In [51]:
print(sents[2])
print(masked_tokens[2])
print(top_words[2])

The Game ) is a 2003 Bollywood Action film directed by Yusuf Khan .
is
is


In [14]:
H = np.array([d["vec_batch_norm"] for d in data])
with open("bert_embeddings.pickle", "rb") as f:
    W = pickle.load(f)
    
H,W = torch.from_numpy(H).double(), torch.from_numpy(W).double()
H = H[:100000]
W = W.T
H.shape, W.shape


(torch.Size([100000, 768]), torch.Size([768, 30522]))

In [15]:
# N = 100000
# d = 128
# l = 25
# H = (torch.randn(N,d) + torch.rand(N,d)**2)

# H = H - torch.mean(H, dim = 0, keepdim = True)
# H = H
# W = torch.randn(d,l)
cov_H = torch.tensor(np.cov(H.detach().cpu().numpy(), rowvar = False))
print(cov_H[0]@W)

tensor([0.5134, 0.5406, 0.4952,  ..., 0.4864, 0.3009, 0.5797])


In [52]:
del data

In [6]:
# import time
# start = time.time()
# Q = cov_H[:1000]@W
# print(Q)
# print(cov_H.shape, W.shape)
# print(time.time() - start)

In [7]:
# u = torch.nn.Parameter(u)
# optimizer = torch.optim.SGD([u], lr=0.001, momentum=0.9, weight_decay=1e-4)

In [65]:
def get_cov_output_projected(u, cov_H, W):
    
    u_normed = u / torch.norm(u)
    P = torch.eye(cov_H.shape[0]) - (u_normed@u_normed.T)
    #P = u_normed@u_normed.T
    print(P.shape, W.shape, cov_H.shape, u_normed.shape)
    first = P@W
    print("done first")
    second = cov_H@first
    print("done second")
    third = P@second
    print("done third")
    fourth = W.T@third
    print("done fourth")
    return fourth
    #return W.T@P@cov_H@P@W

def get_cov_output_total(H,W,n=10000):
    with torch.no_grad():
        Y_hat = H[:n]@W 
        Y_hat = Y_hat - torch.mean(Y_hat, dim = 1, keepdim = True)
        return torch.sum(Y_hat*Y_hat)/Y_hat.shape[0]
    #return torch.tensor(np.cov(Y_hat.detach().cpu().numpy(), rowvar = False))

def eval_total_var(H,W):
    
    with torch.no_grad():
        k = 2000
        Y_sum = torch.zeros(W.shape[1])
        Y_sqr_sum = torch.zeros(W.shape[1])
    
        for i in range(0,len(H), k):
            Y = H[i:i+k]@W
            Y_sum += torch.sum(Y, dim = 0)
            Y_sqr_sum += torch.sum(Y**2, dim = 0)
    
        Y_sqr_sum /= len(H)
        Y_sum /= len(H)
        return (Y_sqr_sum - Y_sum**2).mean() # mean over vocab
    
def get_loss_func(cov_output_projected):
    
    loss =  torch.sum(torch.diag(cov_output_projected))
    print("done loss calculation")
    return loss

def get_loss_func2(u, cov_H, W):

    u_normed = u / torch.norm(u)
    P = torch.eye(cov_H.shape[0]) - (u_normed@u_normed.T)
    #P = u_normed@u_normed.T
    first = P@W
    second = cov_H@first
    third = P@second
    fourth = torch.sum(W*third)/third.shape[1] # mean over vocab
    return fourth 

def get_projection_to_intersection_of_nullspaces(rowspace_projection_matrices: List[np.ndarray], input_dim: int):
    """
    Given a list of rowspace projection matrices P_R(w_1), ..., P_R(w_n),
    this function calculates the projection to the intersection of all nullspasces of the matrices w_1, ..., w_n.
    uses the intersection-projection formula of Ben-Israel 2013 http://benisrael.net/BEN-ISRAEL-NOV-30-13.pdf:
    N(w1)∩ N(w2) ∩ ... ∩ N(wn) = N(P_R(w1) + P_R(w2) + ... + P_R(wn))
    :param rowspace_projection_matrices: List[np.array], a list of rowspace projections
    :param dim: input dim
    """

    I = np.eye(input_dim)
    Q = np.sum(rowspace_projection_matrices, axis = 0)
    P = I - get_rowspace_projection(Q)

    return P

def get_rowspace_projection(W: np.ndarray) -> np.ndarray:
    """
    :param W: the matrix over its nullspace to project
    :return: the projection matrix over the rowspace
    """

    if np.allclose(W, 0):
        w_basis = np.zeros_like(W.T)
    else:
        w_basis = scipy.linalg.orth(W.T) # orthogonal basis

    P_W = w_basis.dot(w_basis.T) # orthogonal projection on W's rowspace

    return P_W

def get_first_pca(H):
    pca = PCA(n_components = 1)
    pca.fit(H)
    return torch.from_numpy(pca.components_.T)

def BCA(H,W,n_components, eps = 1e-8, max_iters = 1000, min_iters = 150, init_pca = True):
    
    P_nullspace = torch.eye(H.shape[1])
    results = []
    #cov_out_total = get_cov_output_total(H,W)
    total_var_orig = eval_total_var(H,W) #get_loss_func(cov_out_total).detach().cpu().numpy().item()
    remaining_var = total_var_orig
    print("Total var original: ", remaining_var)
    H_proj = H.clone()
    rowspace_projs = []
    
    for i in range(n_components):
        
        H_proj = H@P_nullspace # remove previous component 
        cov_H = torch.from_numpy(np.cov(H_proj.detach().cpu().numpy(), rowvar = False))
        
        if init_pca:
            u = get_first_pca(H_proj.detach().cpu().numpy())
        else:
            u = torch.randn(H_proj.shape[1], 1)
        u = torch.nn.Parameter(u)
        optimizer = torch.optim.SGD([u], lr=3*1e-2, momentum=0.25)
        #optimizer = torch.optim.Adam([u])
        
        diff = 10
        j = 0
        loss_vals = [np.inf]
        patience = 10
        patience_counter = 0
        
        while j < max_iters and patience_counter < patience:
            optimizer.zero_grad()
            #cov_out = get_cov_output_projected(u,cov_H,W)
            #loss = get_loss_func(cov_out)
            loss = get_loss_func2(u, cov_H, W)
            loss.backward()
            optimizer.step()
            loss_vals.append(loss.detach().cpu().numpy().item())
            diff = np.abs(loss_vals[-1] - loss_vals[-2])
            
            if diff > eps:
                patience_counter = 0
            else:
                if j > min_iters:
                    patience_counter += 1
                
            if j % 25 == 0: print("j, loss, ", j, loss.detach().cpu().numpy().item(), diff)
            j += 1
        print("finished after {} iters".format(j))
        
        # calculate new nullspace projection to neutralzie component u
        
        u_normed = u / torch.norm(u)
        rowspace_projs.append((u_normed@u_normed.T).detach().cpu().numpy())
        P_nullspace = torch.from_numpy(get_projection_to_intersection_of_nullspaces(rowspace_projs,cov_H.shape[0]))
        #P_nullspace = torch.eye(H_proj.shape[1]).double() - u_normed@u_normed.T
        
        # calcualte explained variance
        #cov_out_total = get_cov_output_total(H,W)
        total_var = eval_total_var(H,W)
        #cov_out_projected = get_cov_output_projected(u,cov_H,W)
        #total_var_projected = get_loss_func(cov_out_projected).detach().cpu().numpy().item()
        total_var_projected = remaining_var - get_loss_func2(u,cov_H,W)
        explained_var = total_var_projected / total_var_orig
        remaining_var = remaining_var - total_var_projected
        
        u = u / u.norm()
        results.append({"vec": u.squeeze().detach().cpu().numpy(), "projected_var": total_var_projected,
                       "total_var": total_var, "explained_var": total_var_projected*100})
    
    return results
        

In [None]:
bca = BCA(H,W,n_components=3, eps = 0.5*1e-4, init_pca = True)

Total var original:  tensor(5.6453)
j, loss,  0 5.103504540583716 inf
j, loss,  25 4.563359069997909 0.01486304000367511
j, loss,  50 4.263849139912718 0.00902478341531232
j, loss,  75 4.100062379082116 0.004877269240622262
j, loss,  100 4.000624601050575 0.003356657719244538
j, loss,  125 3.92816104533149 0.002518444157333377
j, loss,  150 3.874233945310986 0.0018436812287796478
j, loss,  175 3.835143196388439 0.0013372826777038327
j, loss,  200 3.8062502824568307 0.0010118175164461896
j, loss,  225 3.783904661746057 0.0007975775978374955
j, loss,  250 3.766061339880651 0.0006434790743754526
j, loss,  275 3.7515728028505038 0.0005250326030097341
j, loss,  300 3.7397207665125043 0.0004302122285397836
j, loss,  325 3.730005061917473 0.00035265956841534774
j, loss,  350 3.722045472807963 0.00028870560646998555
j, loss,  375 3.715534494587027 0.0002359969256486849
j, loss,  400 3.7102141629080663 0.00019281820038541397
j, loss,  425 3.7058651931832403 0.00015773935132479266
j, loss,  450 

In [None]:
#bca[0]["vec"].T@bca[3]["vec"]

In [19]:
vecs = np.array([x["vec"] for x in bca])
for i,v in enumerate(vecs):
    for j, v2 in enumerate(vecs):
        if j <= i: continue
            
        print(i,j, v@v2.T)

0 1 0.00011314179064859584
0 2 -0.00017821977203982126
1 2 0.0002278588081749272


In [20]:
np.linalg.norm(bca[0]["vec"])

1.0

In [24]:

#for i in range(len(bca)):
    
    #print(bca[i]["explained_var"])
    
#cov_out_total = get_cov_output_total(H,W)
#total_var = get_loss_func(cov_out_total).detach().cpu().numpy().item()    
vars = [x["projected_var"] for x in bca]
print(vars)
print(sum(vars))

[tensor(1.9592, grad_fn=<SubBackward0>), tensor(0.1169, grad_fn=<SubBackward0>), tensor(0.1146, grad_fn=<SubBackward0>)]
tensor(2.1907, grad_fn=<AddBackward0>)


In [60]:
u = bca[0]["vec"]
#u = np.expand_dims(u, axis=1)
vecs = H.detach().cpu().numpy()
vecs_u = vecs@u
idx = np.argsort(vecs_u)
k = 50
top_neg, top_pos = idx[:k], idx[-k:]

sents_pos = sents[top_pos]
sents_neg = sents[top_neg]
gold_pos = masked_tokens[top_pos]
gold_neg = masked_tokens[top_neg]
preds_pos = top_words[top_pos]
preds_neg = top_words[top_neg]

for i in range(10):
    print(gold_pos[i], " -------- ", sents_pos[i], "---------", vecs_u[idx[i]])
    print("======================================")
    
print("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&")

for i in range(10):
    print(gold_neg[i], " -------- ", sents_neg[i])
    print("======================================")

,  --------  The municipality of Berra contains the frazioni ( subdivisions , mainly villages and hamlets ) Cologna and Serravalle . --------- -5.448562479435196
Be'er  --------  Located near Kiryat Malakhi with 98 farms covering an area of 6 , 000 dunams , it falls under the jurisdiction of Be'er Tuvia Regional Council . --------- -4.6270222685213325
,  --------  Hamma is a village and a former municipality in the Eichsfeld district , in Thuringia , Germany . --------- -4.401286464200628
,  --------  Nochern is a municipality and village in the district of Rhein-Lahn , in Rhineland-Palatinate , in western Germany . --------- -4.29506809569322
the  --------  At the time of NRHP listing it was owned by Kennecott Copper Corporation . --------- -4.256045699116503
in  --------  As a designated place in the 2011 Census , Sunset Acres had a population of 61 living in 22 of its 22 total dwellings , a -3 . --------- -4.186547405002317
by  --------  Homeobox protein DLX-5 is a protein that in h

In [None]:
bca[-1]["cov_out"]

In [None]:
bca[-3]["explained_var"]

In [None]:
u,v = np.random.randn(768), np.random.randn(768).T
u,v = u / np.linalg.norm(u), v/np.linalg.norm(v)
u@v

In [None]:
a = [1,2,3,4,5,6,7,8,9,10,11,12]

In [None]:
a[-12:-2]

In [None]:
A = np.random.rand(768,32000)
H = np.random.rand(10000,768)
H*A

In [None]:
W = torch.randn(768,32000)
H = torch.randn(10000, 768)
Y_hat = H[:10000]@W 
torch.mean(Y_hat*Y_hat)

In [13]:
H.shape, W.shape

(torch.Size([100000, 768]), torch.Size([768, 30522]))

In [23]:
H_first_comp = get_first_pca(H)
Y = H[:10000]@W
Y_first_comp = get_first_pca(Y)

In [105]:
total_var = eval_total_var(H,W)
print(total_var)

tensor([4.4340, 4.4635, 4.4105,  ..., 3.8644, 3.4939, 4.9670])


In [108]:
print(total_var.sum()/W.shape[1])

tensor(5.6453)


In [122]:
u = torch.randn(768,1)
get_loss_func2(H_first_comp, cov_H, W)

tensor(5.1035)

In [119]:
5%1

0