In [None]:
import torch

In [None]:
import safetensors
# Read the tensor stored under the "embs" key of the cached safetensors file,
# loaded through the PyTorch ("pt") framework backend.
with safetensors.safe_open(
    "./model_embs_cache/meta-llama__Llama-3.2-1B.safetensors",
    framework="pt",
) as embs_file:
    llama_embs = embs_file.get_tensor("embs")

In [None]:
import plotly
import matplotlib.pyplot as plt

# Heat-map of the Pearson correlation matrix of llama_embs.T
# (torch.corrcoef treats each row of its input as one variable).
plt.figure(dpi=100, figsize=(10, 10))

column_corr = llama_embs.T.corrcoef()
plt.imshow(column_corr.numpy(), cmap='viridis')

In [None]:
import plotly.express

# Distribution of absolute correlations; subtracting the identity matrix
# removes the 1.0 contributed by each diagonal (self-correlation) entry.
abs_corr = llama_embs.T.corrcoef().abs()
identity = torch.eye(llama_embs.shape[1])
abs_corr_without_diag = abs_corr - identity

plotly.express.histogram(abs_corr_without_diag.flatten(), nbins=100)

In [None]:
def multicolinear_corrcoef(x: torch.Tensor) -> torch.Tensor:
    """Correlate each column of ``x`` with its least-squares fit from the other columns.

    For every column ``i`` a linear model (no intercept term) is fitted that
    predicts ``x[:, i]`` from the remaining ``d - 1`` columns, and the Pearson
    correlation between the column and that prediction is recorded. A value
    close to 1 means the column is nearly a linear combination of the others.

    Note: one least-squares problem with an ``(n, d - 1)`` design matrix is
    solved per column, so the cost grows quickly with ``d``.

    Args:
        x: 2-D floating-point tensor of shape ``(n, d)``; rows are
            observations and columns are the variables being compared.

    Returns:
        1-D tensor of length ``d`` (default dtype) whose ``i``-th entry is the
        correlation between column ``i`` and its prediction. Entries are NaN
        where the correlation is undefined (e.g. a constant column), and all
        entries are NaN when ``d < 2`` because there is nothing to predict from.

    Raises:
        ValueError: if ``x`` is not 2-D (raised by the shape unpacking).
    """
    _n_rows, d = x.shape
    if d < 2:
        # No "other" columns exist, so the fit (and its correlation) is undefined.
        return torch.full((d,), float("nan"))

    corrs = []
    for i in range(d):
        x_i = x[:, i]
        x_rest = torch.cat([x[:, :i], x[:, i + 1:]], dim=1)
        # torch.linalg.lstsq(A, B) solves A @ X = B, so the design matrix
        # (the other columns) goes first and the target column second.
        beta = torch.linalg.lstsq(x_rest, x_i.unsqueeze(1)).solution  # (d - 1, 1)
        x_i_pred = (x_rest @ beta).squeeze(1)
        corrs.append(torch.corrcoef(torch.stack([x_i, x_i_pred]))[0, 1].item())
    return torch.tensor(corrs)

In [None]:
# Display the per-column correlation values for the embedding matrix.
per_column_corr = multicolinear_corrcoef(llama_embs)
per_column_corr

In [None]:
# Bar chart with one bar per column of llama_embs.
bar_values = multicolinear_corrcoef(llama_embs)
plotly.express.bar(bar_values)

In [None]:
# Histogram of the per-column correlation values.
hist_values = multicolinear_corrcoef(llama_embs)
plotly.express.histogram(hist_values)