In [444]:
import torch
import numpy as np

In [445]:
# Seed torch's global RNG before any random draw so H, W (and the later
# random init of `u`) are reproducible under Restart & Run All.
torch.manual_seed(0)

N = 1000  # number of samples (rows of H)
d = 32    # input dimension (columns of H, rows of W)
l = 64    # output dimension (columns of W)

# Synthetic inputs: Gaussian noise plus a squared-uniform component.
H = torch.randn(N, d) + torch.rand(N, d)**2

# Center every column so each feature has zero mean.
H = H - torch.mean(H, dim = 0, keepdim = True)
W = torch.randn(d, l)

# (d, d) sample covariance of the features: rowvar=False makes np.cov treat
# columns as variables. numpy computes it in float64; cast once to float32 so
# it matches the dtype of H, W and u used in every later matmul.
cov_H = torch.tensor(np.cov(H.detach().cpu().numpy(), rowvar = False), dtype = torch.float32)


In [446]:
# Random initial direction in input space, shaped as a (d, 1) column vector.
u = torch.randn((d, 1))

In [447]:
# Turn u into a trainable leaf tensor and optimise it with momentum SGD.
u = torch.nn.Parameter(u)
optimizer = torch.optim.SGD(
    [u],
    lr=1e-3,
    momentum=0.9,
    weight_decay=1e-4,
)

In [416]:
def get_cov_output_projected(u, cov_H, W, project = True):
    """Covariance of the outputs ``Y = H @ P @ W`` given ``cov_H = Cov(H)``.

    Because ``P`` is symmetric, ``Cov(H P W) = W^T P Cov(H) P W``.

    Args:
        u: direction to project onto. It is divided by its norm, so only its
            direction matters. This notebook passes a (d, 1) column vector.
        cov_H: (d, d) covariance matrix of the inputs ``H``.
        W: (d, l) weight matrix mapping inputs to outputs.
        project: if True, ``P = u_hat @ u_hat.T`` (rank-1 projector onto
            ``u``). If False, ``P`` is the identity and the unprojected output
            covariance ``W^T Cov(H) W`` is returned; ``u`` is then unused.

    Returns:
        (l, l) covariance matrix of the (optionally projected) outputs, with
        the dtype/device of the inputs.
    """
    if not project:
        # P = I: skip the identity matmuls entirely. This also avoids building
        # a torch.eye with the default dtype/device, which fails to multiply
        # with a float64 or non-CPU cov_H.
        return W.T @ cov_H @ W

    u_normed = u / torch.norm(u)
    P = u_normed @ u_normed.T
    return W.T @ P @ cov_H @ P @ W

def get_loss_func(cov_output_projected):
    """Loss to minimise: the negated total variance (trace) of an output covariance.

    Minimising this maximises the sum of the diagonal entries of
    ``cov_output_projected``.
    """
    diagonal = cov_output_projected.diag()
    total_variance = diagonal.sum()
    return -total_variance

In [408]:
# One forward evaluation at the current u (in notebook order: before training).
cov_out = get_cov_output_projected(u=u.float(), cov_H=cov_H.float(), W=W.float())
get_loss_func(cov_output_projected=cov_out)

tensor(-56.2016, grad_fn=<NegBackward>)

In [409]:
# Display the (l, l) closed-form output covariance from the previous cell, for
# comparison with the empirical covariance computed in the next cell.
cov_out

tensor([[ 1.6231,  1.2822,  1.8030,  ...,  1.1869, -0.3423,  1.0715],
        [ 1.2822,  1.0129,  1.4243,  ...,  0.9376, -0.2704,  0.8464],
        [ 1.8030,  1.4243,  2.0028,  ...,  1.3185, -0.3802,  1.1902],
        ...,
        [ 1.1869,  0.9376,  1.3185,  ...,  0.8680, -0.2503,  0.7835],
        [-0.3423, -0.2704, -0.3802,  ..., -0.2503,  0.0722, -0.2260],
        [ 1.0715,  0.8464,  1.1902,  ...,  0.7835, -0.2260,  0.7073]],
       grad_fn=<MmBackward>)

In [410]:
# Sanity check: the closed form W^T P Cov(H) P W must match the empirical
# covariance of the projected outputs Y = H P W. The closed form is recomputed
# here instead of reusing `cov_out`, which the training loop later rebinds.
u_normed = u / torch.norm(u)
P = u_normed @ u_normed.T
Y = H @ P @ W
cov_Y_empirical = torch.tensor(np.cov(Y.detach().cpu().numpy(), rowvar=False), dtype=torch.float32)
cov_Y_analytic = get_cov_output_projected(u, cov_H.float(), W).detach()
assert torch.allclose(cov_Y_empirical, cov_Y_analytic, rtol=1e-3, atol=1e-4), \
    "closed-form output covariance disagrees with the empirical covariance"
cov_Y_empirical

tensor([[ 1.6231,  1.2822,  1.8030,  ...,  1.1869, -0.3423,  1.0715],
        [ 1.2822,  1.0129,  1.4243,  ...,  0.9376, -0.2704,  0.8464],
        [ 1.8030,  1.4243,  2.0028,  ...,  1.3185, -0.3802,  1.1902],
        ...,
        [ 1.1869,  0.9376,  1.3185,  ...,  0.8680, -0.2503,  0.7835],
        [-0.3423, -0.2704, -0.3802,  ..., -0.2503,  0.0722, -0.2260],
        [ 1.0715,  0.8464,  1.1902,  ...,  0.7835, -0.2260,  0.7073]])

In [449]:
# Optimise u to maximise the output variance captured by projecting onto it
# (get_loss_func returns the negated trace, so minimising it maximises variance).
N_STEPS = 800    # optimisation steps
LOG_EVERY = 100  # print the loss every LOG_EVERY steps

cov_H_f32 = cov_H.float()  # loop-invariant: cast once instead of every step
for i in range(N_STEPS):
    optimizer.zero_grad()
    cov_out = get_cov_output_projected(u, cov_H_f32, W)
    loss = get_loss_func(cov_out)
    loss.backward()
    optimizer.step()

    if i % LOG_EVERY == 0:
        print(loss.item())

-63.87030029296875
-186.16481018066406
-188.79730224609375
-189.3166046142578
-189.4493865966797
-189.48304748535156
-189.4915771484375
-189.49371337890625
-189.4941864013672
-189.4943389892578
-189.49447631835938
-189.4944305419922
-189.49436950683594
-189.49447631835938
-189.49444580078125
-189.49435424804688
-189.49449157714844
-189.49436950683594
-189.49432373046875
-189.494384765625


In [413]:
# Cosine similarity between the learned direction u and each column of W.
# Use a new name instead of rebinding `u`: `u` is the nn.Parameter held by the
# optimizer. Rebinding it to a derived non-leaf tensor would make a re-run of
# the training cell backprop through an already-freed graph and never update it.
u_unit = (u / torch.norm(u)).detach().squeeze()            # (d,)
WT = W.T / torch.norm(W.T, dim = 1, keepdim = True)        # (l, d): unit-norm columns of W as rows
WT @ u_unit

tensor([-0.2677, -0.0253, -0.0333, -0.1426,  0.4683,  0.4370, -0.0010, -0.2550,
         0.3309,  0.0120, -0.0158,  0.3556,  0.0782,  0.1854, -0.2046,  0.3055,
         0.0177, -0.0521,  0.3623, -0.0094, -0.3348,  0.4240,  0.2525, -0.3816,
         0.1589, -0.4915,  0.0954, -0.0662, -0.4441, -0.0526,  0.0868, -0.6774,
        -0.3135,  0.0286,  0.1081,  0.3443, -0.0087, -0.2987, -0.2919, -0.2000,
         0.3860, -0.1248, -0.2338,  0.2126,  0.1587, -0.2507,  0.3060,  0.3181,
         0.0126, -0.6031,  0.0423, -0.1774,  0.0869,  0.3745, -0.0361, -0.2441,
         0.1190,  0.4289, -0.2458, -0.2765,  0.1691,  0.0349, -0.0205,  0.1098],
       grad_fn=<MvBackward>)

In [450]:
# Fraction of the total output variance trace(W^T Cov(H) W) that survives
# projecting the inputs onto the learned direction u. Pure evaluation, so no
# autograd graph is needed.
with torch.no_grad():
    cov_H_f32 = cov_H.float()
    cov_out_total = get_cov_output_projected(u, cov_H_f32, W, project = False)
    total_var = -get_loss_func(cov_out_total)
    cov_out_projected = get_cov_output_projected(u, cov_H_f32, W, project = True)
    total_var_projected = -get_loss_func(cov_out_projected)
    explained_var = total_var_projected / total_var
print("Explained variance in output: {}%".format((explained_var * 100).item()))

Explained variance in output: 8.756929397583008%
