In [280]:
import torch
from torch import autograd
import torch.nn
import numpy
import matplotlib
from matplotlib import pyplot as plt
%matplotlib inline

In [342]:
# Problem sizes: N samples, input dim d_x, hidden width h, output dim d_y.
N, d_x, h, d_y = 1000, 10, 5, 10

In [351]:
# Synthetic noise-free linear regression data.
# Seeded so that Restart & Run All reproduces the same data (and the model
# initialisation later in this cell).
torch.manual_seed(0)
X = torch.rand(N, d_x) * 2 - 1  # inputs, uniform on [-1, 1), shape (N, d_x)
W = torch.randn(d_x, d_y)       # ground-truth linear map, shape (d_x, d_y)
Y = X @ W                       # targets, shape (N, d_y)


class LinearNet(torch.nn.Module):
    """Two-layer linear network with no nonlinearity: ``y = x @ B @ A``.

    Args:
        d_x: input dimension.
        h: hidden width; the end-to-end map ``B @ A`` has rank at most ``h``.
        d_y: output dimension.

    Parameters are registered A first, then B, so ``parameters()`` (and any
    flat vector built by iterating it) lists A's entries before B's.
    """

    def __init__(self, d_x, h, d_y):
        super().__init__()
        # Both factors start from a standard-normal initialisation.
        self.A = torch.nn.Parameter(torch.randn(h, d_y))  # (h, d_y): hidden -> output
        self.B = torch.nn.Parameter(torch.randn(d_x, h))  # (d_x, h): input -> hidden

    def forward(self, input):
        """Map an input whose last dim is d_x to an output whose last dim is d_y."""
        return input @ self.B @ self.A

def compute_loss(X, Y, model):
    """Half squared error, averaged over samples: 0.5 * ||model(X) - Y||_F^2 / n.

    ``n`` is ``X.size(0)``. The 1/2 factor keeps the gradient free of a
    stray factor of 2 (d/dr of 0.5*r^2 is r).
    """
    n_samples = X.size(0)
    sse = torch.nn.functional.mse_loss(model(X), Y, reduction='sum')
    return 0.5 * sse / n_samples

def eval_grad(X, Y, model):
    """Return the gradient of ``compute_loss(X, Y, model)`` as one flat vector.

    Each parameter's gradient is flattened and concatenated in
    ``model.parameters()`` order (A then B for LinearNet). That is the same
    order ``eval_hessian`` iterates, so gradient entries line up with Hessian
    rows/columns for any model, not just one with attributes A and B.

    Side effect: zeroes and then overwrites each parameter's ``.grad``.
    """
    model.zero_grad()
    loss = compute_loss(X, Y, model)
    loss.backward()
    # detach() replaces the legacy .data. torch.cat copies, so the result does
    # not alias the .grad buffers that the next zero_grad() will reset.
    # A parameter that received no gradient contributes zeros, not a crash.
    return torch.cat([
        (torch.zeros_like(p) if p.grad is None else p.grad.detach()).reshape(-1)
        for p in model.parameters()
    ])


# eval Hessian matrix
def eval_hessian(X, Y, model):
    """Return the dense Hessian of ``compute_loss(X, Y, model)`` w.r.t. all parameters.

    Rows and columns are indexed like the flat gradient: each parameter is
    flattened and concatenated in ``model.parameters()`` order. Row ``i`` comes
    from differentiating the ``i``-th gradient entry again, which costs one
    backward pass per parameter entry, so this is only suitable for small models.

    Returns:
        A detached ``(P, P)`` tensor, where P is the total number of parameter
        entries. It has the dtype/device of the gradient vector.
    """
    params = list(model.parameters())
    loss = compute_loss(X, Y, model)
    # create_graph=True makes the first-order gradient itself differentiable.
    first_grads = autograd.grad(loss, params, create_graph=True)
    g_vector = torch.cat([g.reshape(-1) for g in first_grads])
    n = g_vector.numel()

    if not g_vector.requires_grad:
        # The gradient is constant in the parameters, so the Hessian is zero.
        return torch.zeros(n, n, dtype=g_vector.dtype, device=g_vector.device)

    rows = []
    for idx in range(n):
        # retain_graph: the first-order graph is reused for every row.
        # No create_graph here, because third derivatives are never needed.
        # allow_unused: a gradient entry need not depend on every parameter.
        second = autograd.grad(g_vector[idx], params,
                               retain_graph=True, allow_unused=True)
        rows.append(torch.cat([
            (torch.zeros_like(p) if g2 is None else g2).reshape(-1)
            for p, g2 in zip(params, second)
        ]))
    return torch.stack(rows)

# Two-layer linear net d_x -> h -> d_y; its weights are overwritten below.
model = LinearNet(d_x=d_x, h=h, d_y=d_y)

In [352]:
# Sample second moments and Sigma = Sigma_YX Sigma_XX^{-1} Sigma_XY.
# Its eigenvectors are used below to build critical points of the linear net
# (the construction matches Baldi & Hornik's characterisation).
Sigma_XX = X.t() @ X / X.size(0)
Sigma_YX = Y.t() @ X / X.size(0)

# solve() instead of forming the explicit inverse: cheaper and better conditioned.
Sigma = Sigma_YX @ torch.linalg.solve(Sigma_XX, Sigma_YX.t())
# Remove round-off asymmetry before using a symmetric eigensolver.
Sigma = 0.5 * (Sigma + Sigma.t())
# torch.symeig is deprecated/removed in recent PyTorch. linalg.eigh likewise
# returns eigenvalues in ascending order, with matching eigenvector columns.
s, U = torch.linalg.eigh(Sigma)
print(s)

tensor([ 0.0280,  0.1337,  0.3934,  0.7907,  2.0025,  3.4378,  4.0731,  4.5526,
         7.3233, 11.4486])


In [353]:
# Build a critical point from h eigenvectors of Sigma.
# s/U are sorted by ascending eigenvalue, so indices 1..h pick the 2nd..(h+1)-th
# smallest directions rather than the top-h ones. This is presumably deliberate,
# to land on a saddle rather than the global minimum (TODO confirm intent).
indices = torch.arange(1, h + 1)
U_hat = torch.index_select(U, dim=1, index=indices)
with torch.no_grad():
    # copy_ writes into the existing Parameters (they stay registered leaf
    # tensors and contiguous) instead of swapping storage via .data.
    model.A.copy_(U_hat.t())
    model.B.copy_(torch.linalg.solve(Sigma_XX, Sigma_YX.t() @ U_hat))

In [354]:
# Full-batch gradient norm at the constructed point; near 0 at a critical point.
torch.norm(eval_grad(X, Y, model))

tensor(2.5235e-06)

In [355]:
# Dense Hessian of the full-batch loss at the constructed point.
H = eval_hessian(X=X, Y=Y, model=model)

In [356]:
# Smallest Hessian eigenvalue; a negative value at a critical point means it is
# not a local minimum (there is a descent direction). torch.symeig has been
# removed from recent PyTorch, so use eigvalsh on the symmetrised Hessian.
torch.linalg.eigvalsh(0.5 * (H + H.t())).min()

tensor(-1.6968)

In [359]:
# Eigendecompose the (symmetrised) Hessian. Eigenvalues come back ascending, so
# column 0 is the direction of most negative curvature.
s_H, U_H = torch.linalg.eigh(0.5 * (H + H.t()))
v = U_H[:, 0].unsqueeze(1)

# Average over mini-batches of:
#   corr    = (g^T v)^2, the stochastic gradient's energy along v
#   Sigma_g = g g^T,     the second moment of the mini-batch gradients
batch_size = 10
batch_starts = range(0, X.size(0), batch_size)
# Count the batches the loop actually visits. N / batch_size would be wrong
# (and non-integer) when N is not a multiple of batch_size.
n_batches = len(batch_starts)
corr = 0
Sigma_g = 0
for i in batch_starts:
    x, y = X[i:i+batch_size, :], Y[i:i+batch_size, :]
    g = eval_grad(x, y, model).unsqueeze(1)
    corr += (g.t() @ v).item()**2 / n_batches
    Sigma_g += g @ g.t() / n_batches

print(corr)
s_g, _ = torch.linalg.eigh(Sigma_g)

0.09891105472979152


In [360]:
# Eigenvalues of Sigma_g in ascending order. Sigma_g is an average of outer
# products g g^T, so it is PSD in exact arithmetic; the tiny negative values
# are floating-point round-off.
s_g

tensor([-2.3906e-07, -1.7494e-07, -1.4929e-07, -1.0468e-07, -1.0066e-07,
        -7.7774e-08, -7.0437e-08, -5.3709e-08, -4.9904e-08, -4.5665e-08,
        -3.8384e-08, -3.4583e-08, -3.0678e-08, -2.0996e-08, -1.5755e-08,
        -1.3425e-08, -1.1622e-08, -8.0070e-09, -5.6599e-09, -2.5137e-09,
        -2.0455e-09, -1.4807e-09, -1.7743e-10, -1.5230e-10, -1.3000e-14,
        -6.2604e-15, -4.5505e-15, -3.4909e-15, -1.8968e-15, -1.0655e-15,
        -2.8005e-16, -4.5582e-17,  8.4531e-16,  1.3600e-15,  1.6731e-15,
         2.0102e-15,  2.2287e-15,  2.7159e-15,  3.1691e-15,  4.2579e-15,
         4.7102e-15,  5.2120e-15,  5.2915e-15,  7.1228e-15,  8.1337e-15,
         9.6185e-15,  1.1578e-14,  1.4265e-14,  1.1825e-13,  1.5516e-11,
         2.8647e-10,  9.7321e-10,  1.2821e-09,  1.9528e-09,  4.5618e-09,
         5.1062e-09,  5.7166e-09,  7.1300e-09,  1.0208e-08,  1.2594e-08,
         1.3892e-08,  1.7552e-08,  1.8469e-08,  2.1486e-08,  2.8210e-08,
         3.2055e-08,  3.3710e-08,  4.1189e-08,  5.0