In [None]:
import transformers
import plotly.express
import torch
import sklearn.decomposition
from torch import Tensor

def pca(embs: Tensor, low_dim: int) -> Tensor:
    """Reduce `embs` to `low_dim` principal components with scikit-learn PCA.

    Args:
        embs: Tensor handed to ``PCA.fit_transform`` as a NumPy array. It may
            live on any device and may require grad; it is detached and moved
            to the CPU before conversion.
        low_dim: Value passed to PCA as ``n_components``.

    Returns:
        A new CPU tensor built from the ``fit_transform`` result.
    """
    # Named `reducer` so the local does not shadow this function's own name.
    reducer = sklearn.decomposition.PCA(n_components=low_dim)
    # `.numpy()` is only defined for CPU tensors, so move to the host first
    # (`.cpu()` returns the tensor unchanged when it is already there).
    reduced_embs = reducer.fit_transform(embs.detach().cpu().numpy())
    return torch.tensor(reduced_embs)

def fourier(embs: Tensor) -> Tensor:
    """Return the one-dimensional discrete Fourier transform of `embs` along dim 0.

    The transform is applied independently to every slice along the first
    dimension; the result is complex-valued.
    """
    spectrum = torch.fft.fft(embs, dim=0)
    return spectrum

def vis_emb_plotly(embs: Tensor, title: str) -> plotly.graph_objects.Figure:
    """Draw `embs` as a blue-scale heatmap and return the Plotly figure.

    The tensor is transposed before plotting, so the first dimension of
    `embs` runs along the x-axis (labelled "Token Value") and the second
    along the y-axis (labelled "Feature").

    Args:
        embs: Tensor to plot; it is detached and copied to the CPU first.
        title: Figure title.

    Returns:
        The figure created by ``plotly.express.imshow`` with title and axis
        labels applied.
    """
    heatmap_data = embs.detach().cpu().T
    fig = plotly.express.imshow(
        heatmap_data,
        aspect='auto',
        color_continuous_scale="blues",
    )
    # Each update_* call mutates the figure in place and hands the same
    # figure back, so applying them one by one matches a chained call.
    fig.update_layout(title=title)
    fig.update_xaxes(title="Token Value")
    fig.update_yaxes(title="Feature")
    return fig

# Single source of truth for the checkpoint so the model and tokenizer
# can never be loaded from different repositories by accident.
MODEL_NAME = "meta-llama/Llama-3.2-1B"

# eval() switches the module to inference mode and returns the module itself,
# so one call at load time is sufficient.
llama1b = transformers.AutoModel.from_pretrained(MODEL_NAME).eval()
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

# One arithmetic prompt per integer in [0, 1000).
inputs_str = [f"{x} + 1 = " for x in range(0, 1000)]
# NOTE(review): no padding is requested, so building one "pt" batch presumably
# relies on every prompt tokenizing to the same length - confirm.
inputs = tokenizer(inputs_str, return_tensors="pt")

# Inference only: skip autograd bookkeeping while collecting hidden states.
with torch.no_grad():
    outputs = llama1b(**inputs, output_hidden_states=True)


In [None]:
# Plot a 16-component PCA of the last-position hidden state for each chosen layer.
for hidden_state_layer_idx in (10, 11, 12, 13):
    layer_hidden_states = outputs.hidden_states[hidden_state_layer_idx]
    # Keep only index -1 along the second axis (the final position of every prompt).
    hidden_states_last_token = layer_hidden_states[:, -1, :]
    reduced_last_token = pca(hidden_states_last_token.cpu(), 16)
    layer_title = f"hidden states in layer {hidden_state_layer_idx} of LLama 1B"
    display(vis_emb_plotly(reduced_last_token, layer_title))