In [2]:
import pickle
from pathlib import Path
from typing import Dict, List, Union

import numpy as np
import torch
from scipy.sparse.linalg import svds, LinearOperator
from tqdm.auto import tqdm

from matvecs import matvec_old, r_matvec_old

In [3]:
def load_weight(path: Union[str, Path]) -> Dict[str, List[torch.Tensor]]:
    """Load per-layer gradient samples from a pickle file.

    Args:
        path: Pickle file to read; a ``str`` is accepted as well as a ``Path``.

    Returns:
        The unpickled object. This notebook indexes it as
        ``weights[layer_name]`` and treats each value as a list of
        ``torch.Tensor`` gradients, hence the annotation; the structure is
        not validated here.

    Warning:
        ``pickle.load`` can execute arbitrary code embedded in the file —
        only call this on files you produced or otherwise trust.
    """
    with Path(path).open("rb") as f:
        return pickle.load(f)


In [4]:
# Gradient samples keyed by module name (indexed per layer in the next cell).
weights = load_weight(path=Path("fishers/tensor_grad57_out.pickle"))

In [13]:
# Per-sample gradients of a single dense layer; each entry is a 2-D (m, n) tensor.
key = "bert.encoder.layer.0.output.dense"
l_grads: List[torch.Tensor] = weights[key]
if len(l_grads) == 0:
    raise ValueError(f"no gradients stored under {key!r}")
m, n = l_grads[0].shape

# A gradient with a different shape but the same element count would be
# silently mis-reshaped below, so check all shapes up front.
bad_shapes = {tuple(g.shape) for g in l_grads} - {(m, n)}
if bad_shapes:
    raise ValueError(
        f"{key!r}: expected every gradient to have shape {(m, n)}, "
        f"also found {sorted(bad_shapes)}"
    )

# Row-major flatten followed by a column-major (n, m) reshape yields each
# gradient's transpose, so grad_vectors[k] == l_grads[k].T and the stack is (N, n, m).
# detach()/cpu() let .numpy() work even for tensors that track grad or live on GPU.
# NOTE(review): assumes this (N, n, m) layout is what matvec_old / r_matvec_old expect.
grad_vectors = np.stack([
    g.detach().cpu().reshape(-1).numpy().reshape(n, m, order="F")
    for g in l_grads
])

In [14]:
# Leading singular pair of an (m**2, n**2) operator. The left/right vectors are
# reshaped (column-major) into (m, m) and (n, n) factors below, which suggests a
# Van Loan–Pitsianis style nearest-Kronecker-product fit XF ⊗ YF.
# NOTE(review): the operator is defined by matvec_old / r_matvec_old from the
# local `matvecs` module — confirm they form a forward/adjoint pair.
linop_m = LinearOperator(
    shape=(m**2, n**2),
    matvec=lambda x: matvec_old(x, grad_vectors),
    rmatvec=lambda x: r_matvec_old(x, grad_vectors),
)

u, s, vt = svds(linop_m, k=1, return_singular_vectors=True)

# svds makes no ordering guarantee; sort descending so column 0 is the top pair
# (a no-op for k=1, but keeps this correct if k is raised).
sidx = np.argsort(-s)
s = s[sidx]
u = u[:, sidx]

# Singular vectors are only defined up to a joint sign flip (u, v) -> (-u, -v),
# and svds' random starting vector can return either. Pin the sign so the left
# factor has non-negative trace; flipping both sides leaves XF ⊗ YF unchanged.
if np.trace(u[:, 0].reshape(m, m, order="F")) < 0:
    u, vt = -u, -vt
v = vt[sidx, :].T

x_ = u[:, 0] * s[0]  # fold the singular value into the left factor
y_ = v[:, 0]
XF = x_.reshape(m, m, order="F")
YF = y_.reshape(n, n, order="F")

out_dir = Path("fisher_factors/kron_factors_int")
out_dir.mkdir(parents=True, exist_ok=True)  # np.save does not create directories
np.save(out_dir / "C1sgd_0_output.npy", XF)
np.save(out_dir / "B1sgd_0_output.npy", YF)