In [1]:
device = "cuda:6"

### Preliminaries

In [2]:
import itertools
import random
import collections


import transformers
import torch
import tqdm.auto
from torch import Tensor

In [3]:
def sinusoidal_encode(
    x: Tensor,
    embedding_dim: int,
    min_value: int,
    max_value: int,
    use_l2_norm: bool = False,
    norm_const: float | None = None,
) -> Tensor:
    """
    Encodes a tensor of numbers into a sinusoidal representation, inspired by how absolute positional
    encoding works in transformers.

    The encoding is an evaluation of a sine and cosine function at different frequencies, where the
    frequency is determined by the embedding dimension and the allowed range of the input values.

    >>> sinusoidal_encode(
    ...     torch.tensor([-5, 2, 1, 0]),
    ...     embedding_dim=6,
    ...     min_value=-5,
    ...     max_value=5,
    ... )
    tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],
            [ 0.6570,  0.7539, -0.1073, -0.9942,  0.9980,  0.0627],
            [-0.2794,  0.9602,  0.3491, -0.9371,  0.9616,  0.2746],
            [-0.9589,  0.2837,  0.7317, -0.6816,  0.8806,  0.4738]])
    """

    if embedding_dim % 2 != 0 and not use_l2_norm:
        raise ValueError("Embedding dimension must be even")

    if use_l2_norm:
        if embedding_dim % 2 == 0:
            reserved_dim = 2
        else:
            reserved_dim = 1
        embedding_dim -= reserved_dim
    else:
        reserved_dim = 0  # will not be used

    domain = max_value - min_value
    y_shape = x.shape + (embedding_dim,)
    y = torch.zeros(y_shape, device=x.device)
    even_indices = torch.arange(0, embedding_dim, 2)
    log_term = torch.log(torch.tensor(domain)) / embedding_dim
    div_term = torch.exp(even_indices * -log_term)
    x = x - min_value
    values = x.unsqueeze(-1).float() * div_term
    y[..., 0::2] = torch.sin(values)
    y[..., 1::2] = torch.cos(values)

    if use_l2_norm:
        y = torch.cat([y, torch.ones_like(y[..., :reserved_dim])], dim=-1)
        y /= y.norm(dim=-1, keepdim=True, p=2)

    if norm_const is not None:
        y *= norm_const

    return y

### Prepare model and data

In [4]:
model = transformers.AutoModel.from_pretrained("meta-llama/Llama-3.2-1B").eval()
tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = model.to(device).eval()

In [5]:
all_values = torch.arange(0, 1000)
mask = torch.rand(len(all_values), generator=torch.Generator().manual_seed(0))
train_mask = mask < 0.9
valid_mask = ~train_mask & (mask < 0.95)
test_mask = ~train_mask & ~valid_mask

train_values = all_values[train_mask]
valid_values = all_values[valid_mask]
test_values = all_values[test_mask]

In [6]:
def make_str_input(nums: list) -> str:
    return str(nums[0]) + "".join(str(n).zfill(3) for n in nums[1:])

make_str_input([3, 500, 789]), make_str_input([3, 0, 1])

('3500789', '3000001')

In [7]:
rng = random.Random(0)

train_size = 200_000

x_values_train = [(i, j) for i, j in zip(
    rng.choices(train_values.tolist(), k=train_size),
    rng.choices(train_values.tolist(), k=train_size)
)]
x_values_valid = list(itertools.product(valid_values.tolist(), repeat=2))
x_values_test = list(itertools.product(test_values.tolist(), repeat=2))

x_inputs_valid = tokenizer(list(map(make_str_input, x_values_valid)), return_tensors="pt")
x_inputs_valid = tokenizer(list(map(make_str_input, x_values_test)), return_tensors="pt")

In [8]:
def get_hidden_states(model, str_inputs: list[str]) -> collections.defaultdict[int, Tensor]:
    model.eval()
    hidden_states = collections.defaultdict(list)
    with torch.no_grad():
        for batch_str in itertools.batched(tqdm.auto.tqdm(str_inputs), n=128):
            batch_inputs = tokenizer(batch_str, return_tensors="pt")
            hidden_reprs = model(**batch_inputs.to(model.device), output_hidden_states=True).hidden_states
            for layer_idx, hidden_state in enumerate(hidden_reprs):
                hidden_states[layer_idx].extend(hidden_state[:, -1, :].detach().cpu())
    return {k: torch.stack(v) for k, v in hidden_states.items()}

In [9]:
train_hidden_states = get_hidden_states(model, list(map(make_str_input, x_values_train)))
valid_hidden_states = get_hidden_states(model, list(map(make_str_input, x_values_valid)))
test_hidden_states = get_hidden_states(model, list(map(make_str_input, x_values_test)))

  0%|          | 0/200000 [00:00<?, ?it/s]

  0%|          | 0/3249 [00:00<?, ?it/s]

  0%|          | 0/3025 [00:00<?, ?it/s]

### Probing

In [10]:
basis_embs = sinusoidal_encode(
    torch.arange(1000),
    min_value=0,
    max_value=1000,
    embedding_dim=train_hidden_states[0].shape[-1],
)

In [11]:
class ClassifierProbe(torch.nn.Module):
    def __init__(self, emb_dim: int, hidden_dim: int, basis: torch.Tensor, heldout_mask: torch.Tensor):
        super().__init__()
        self.emb_to_latent = torch.nn.Linear(emb_dim, hidden_dim, bias=True)
        self.basis_to_latent = torch.nn.Linear(basis.shape[-1], hidden_dim, bias=True)
        self.basis: torch.nn.Buffer
        self.heldout_mask: torch.nn.Buffer
        self.register_buffer("basis", basis)
        self.register_buffer("heldout_mask", heldout_mask)
    def forward(self, x: Tensor, holdout_eval_tokens: bool) -> Tensor:
        latent_x = self.emb_to_latent(x)
        # during training, model learns to choose among only training tokens
        # but during eval, model must choose among all tokens
        # this means that the model is never exposed to the eval tokens during training
        latent_choices = self.basis_to_latent(self.basis)
        logits = latent_x @ latent_choices.T
        if holdout_eval_tokens:
            logits[:, self.heldout_mask] = float("-inf")
        return logits

In [13]:
torch.manual_seed(0)
probe = ClassifierProbe(
    emb_dim=train_hidden_states[0].shape[-1],
    hidden_dim=100,
    basis=basis_embs,
    heldout_mask=test_mask,
).to(device)

layer_idx = 5 # choose the layer to probe
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
train_labels = torch.tensor([x[0] for x in x_values_train]) # we want to decode the first number token from hidden representation of the last token
rng = torch.Generator().manual_seed(0)
for i in range(10000):
    probe.train()
    optimizer.zero_grad()
    minibatch_idcs = torch.randint(len(train_labels), size=(256,), generator=rng)
    x = train_hidden_states[layer_idx][minibatch_idcs].to(device)
    y = train_labels[minibatch_idcs].to(device)
    logits = probe(x, holdout_eval_tokens=True)
    # add l1 regularization of all params to the loss
    loss = torch.nn.functional.cross_entropy(logits, y) + 0.001 * sum(p.abs().sum() for p in probe.parameters())
    loss.backward()
    optimizer.step()
    if i % 500 == 0:
        train_acc = (logits.argmax(dim=-1) == y).float().mean()
        probe.eval()
        with torch.no_grad():
            valid_logits = probe(valid_hidden_states[layer_idx].to(device), holdout_eval_tokens=False)
            valid_labels = torch.tensor([x[0] for x in x_values_valid]).to(device)
            valid_loss = torch.nn.functional.cross_entropy(valid_logits, valid_labels)
            accuracy = (valid_logits.argmax(dim=-1) == valid_labels).float().mean()
        print(f"{i=:<5} train loss: {loss.item():.2f} train acc: {train_acc.item():.2f}  val loss: {valid_loss.item():.2f} valid acc: {accuracy.item():.2f}")
            

i=0     train loss: 11.41 train acc: 0.00  val loss: 6.88 valid acc: 0.00
i=500   train loss: 2.04 train acc: 0.84  val loss: 1.28 valid acc: 0.77
i=1000  train loss: 1.62 train acc: 0.88  val loss: 1.09 valid acc: 0.70
i=1500  train loss: 1.46 train acc: 0.92  val loss: 0.94 valid acc: 0.80
i=2000  train loss: 1.39 train acc: 0.92  val loss: 0.98 valid acc: 0.73
i=2500  train loss: 1.30 train acc: 0.96  val loss: 0.96 valid acc: 0.79
i=3000  train loss: 1.27 train acc: 0.95  val loss: 0.95 valid acc: 0.78
i=3500  train loss: 1.25 train acc: 0.94  val loss: 0.81 valid acc: 0.80
i=4000  train loss: 1.20 train acc: 0.93  val loss: 0.83 valid acc: 0.80
i=4500  train loss: 1.25 train acc: 0.89  val loss: 0.89 valid acc: 0.81
i=5000  train loss: 1.18 train acc: 0.94  val loss: 0.78 valid acc: 0.79
i=5500  train loss: 1.14 train acc: 0.93  val loss: 0.77 valid acc: 0.81
i=6000  train loss: 1.09 train acc: 0.96  val loss: 0.81 valid acc: 0.80
i=6500  train loss: 1.13 train acc: 0.93  val loss