In [1]:
import pickle
import numpy as np
import torch
from tqdm.auto import tqdm
from scipy.sparse.linalg import svds, LinearOperator
from pathlib import Path
from typing import Dict, List

In [2]:
def matvec(vec: np.ndarray, grad_vectors: np.ndarray) -> np.ndarray:
    """Apply the operator ``V -> sum_i G_i.T @ V @ G_i`` to a flattened matrix.

    ``vec`` is unflattened column-major (Fortran order) into an ``(m, m)``
    matrix ``V``. The ``(n, n)`` sum over all samples ``G_i = grad_vectors[i]``
    is flattened column-major again before it is returned.

    The sum is accumulated one sample at a time with two matrix products.
    A single three-operand ``np.einsum`` call without ``optimize`` would walk
    the full ``k * m**2 * n**2`` index space, which is intractable for
    realistic layer sizes; the per-sample form costs
    ``k * (m*m*n + m*n*n)`` and only needs ``O(m*n + n*n)`` scratch memory.

    Args:
        vec: A numpy array holding ``m*m`` elements.
        grad_vectors: A numpy array of shape (k, m, n).

    Returns:
        A numpy array of shape (n*n,).

    Raises:
        ValueError: If ``grad_vectors`` is not 3-D or ``vec`` does not hold
            exactly ``m*m`` elements.
    """
    _, m, n = grad_vectors.shape
    V = vec.reshape(m, m, order="F")

    # Same dtype promotion the einsum formulation would have produced.
    total = np.zeros((n, n), dtype=np.result_type(V.dtype, grad_vectors.dtype))
    for G in grad_vectors:
        # (n, m) @ ((m, m) @ (m, n)) -> (n, n)
        total += G.T @ (V @ G)

    # Column-major flattening, matching the convention used to unflatten vec.
    return total.ravel(order="F")

def r_matvec(vec: np.ndarray, grad_vectors: np.ndarray) -> np.ndarray:
    """Apply the adjoint operator ``V -> sum_i G_i @ V @ G_i.T``.

    ``vec`` is unflattened column-major (Fortran order) into an ``(n, n)``
    matrix ``V``. The ``(m, m)`` sum over all samples ``G_i = grad_vectors[i]``
    is flattened column-major again before it is returned.

    The sum is accumulated one sample at a time with two matrix products.
    A single three-operand ``np.einsum`` call without ``optimize`` would walk
    the full ``k * m**2 * n**2`` index space, which is intractable for
    realistic layer sizes; the per-sample form costs
    ``k * (m*n*n + m*m*n)`` and only needs ``O(m*n + m*m)`` scratch memory.

    Args:
        vec: A numpy array holding ``n*n`` elements.
        grad_vectors: A numpy array of shape (k, m, n).

    Returns:
        A numpy array of shape (m*m,).

    Raises:
        ValueError: If ``grad_vectors`` is not 3-D or ``vec`` does not hold
            exactly ``n*n`` elements.
    """
    _, m, n = grad_vectors.shape
    V = vec.reshape(n, n, order="F")

    # Same dtype promotion the einsum formulation would have produced.
    total = np.zeros((m, m), dtype=np.result_type(V.dtype, grad_vectors.dtype))
    for G in grad_vectors:
        # ((m, n) @ (n, n)) @ (n, m) -> (m, m)
        total += (G @ V) @ G.T

    # Column-major flattening, matching the convention used to unflatten vec.
    return total.ravel(order="F")

def load_weight(path: Path) -> Dict[str, List[torch.Tensor]]:
    """Load gradients from a pickle file.

    Args:
        path: Location of the pickle file to read.

    Returns:
        The unpickled object, returned as-is. The notebook indexes it by
        layer name and treats each value as a list of 2-D tensors; nothing
        here validates that structure.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point this at gradient dumps you produced yourself or
    # otherwise trust.
    with path.open("rb") as f:
        return pickle.load(f)


In [3]:
# Load the pickled gradients; the next cell indexes the result by layer name.
weights = load_weight(Path("fishers/tensor_grad57_out.pickle"))

In [4]:
key = "bert.encoder.layer.0.intermediate.dense"
l_grads: List[torch.Tensor] = weights[key]
m, n = l_grads[0].shape

# Flattening each (m, n) gradient row-major and then refilling an (n, m)
# array column-major yields the transpose of that gradient, so every slice
# grad_vectors[i] equals l_grads[i].T.
l_gradss = [grad.reshape(-1).numpy() for grad in l_grads]
grad_vectors = np.stack([grad.reshape(n, m, order="F") for grad in l_gradss])

# Downstream cells rely on this (num_gradients, n, m) layout.
assert grad_vectors.shape == (len(l_grads), n, m)

In [5]:
# Shape check: (number of gradients, n, m) after the per-gradient reshape above.
grad_vectors.shape

(268, 768, 3072)

In [None]:
%%time 

linop_m = LinearOperator(
    shape=(m**2, n**2),
    matvec=lambda x: matvec(x, grad_vectors),
    rmatvec=lambda x: r_matvec(x, grad_vectors) 
)

u, s, vt = svds(linop_m, k=1, return_singular_vectors=True)

In [None]:
# Reorder the singular triplets by descending singular value. With a single
# triplet this changes nothing, but it keeps the cell valid if more are
# requested.
sidx = np.argsort(-s)
s = s[sidx]
u = u[:, sidx]
# svds returns the right singular vectors as rows of vt; transpose so that
# column j of v pairs with column j of u.
v = vt[sidx, :].T

# Leading pair, with the singular value folded into the left factor.
x_ = u[:, 0] * s[0]
y_ = v[:, 0]
# Unflatten column-major, the same order="F" convention matvec / r_matvec use.
XF = x_.reshape(m, m, order="F")
YF = y_.reshape(n, n, order="F")