In [1]:
device = "cuda:7"

### Preliminaries

In [2]:
import itertools
import random
import collections


import transformers
import torch
import tqdm.auto
from torch import Tensor

In [3]:
def sinusoidal_encode(
    x: Tensor,
    embedding_dim: int,
    min_value: int,
    max_value: int,
    use_l2_norm: bool = False,
    norm_const: float | None = None,
) -> Tensor:
    """
    Encodes a tensor of numbers into a sinusoidal representation, inspired by how absolute positional
    encoding works in transformers.

    The encoding is an evaluation of a sine and cosine function at different frequencies, where the
    frequency is determined by the embedding dimension and the allowed range of the input values.

    >>> sinusoidal_encode(
    ...     torch.tensor([-5, 2, 1, 0]),
    ...     embedding_dim=6,
    ...     min_value=-5,
    ...     max_value=5,
    ... )
    tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],
            [ 0.6570,  0.7539, -0.1073, -0.9942,  0.9980,  0.0627],
            [-0.2794,  0.9602,  0.3491, -0.9371,  0.9616,  0.2746],
            [-0.9589,  0.2837,  0.7317, -0.6816,  0.8806,  0.4738]])
    """

    if embedding_dim % 2 != 0 and not use_l2_norm:
        raise ValueError("Embedding dimension must be even")

    if use_l2_norm:
        if embedding_dim % 2 == 0:
            reserved_dim = 2
        else:
            reserved_dim = 1
        embedding_dim -= reserved_dim
    else:
        reserved_dim = 0  # will not be used

    domain = max_value - min_value
    y_shape = x.shape + (embedding_dim,)
    y = torch.zeros(y_shape, device=x.device)
    even_indices = torch.arange(0, embedding_dim, 2)
    log_term = torch.log(torch.tensor(domain)) / embedding_dim
    div_term = torch.exp(even_indices * -log_term)
    x = x - min_value
    values = x.unsqueeze(-1).float() * div_term
    y[..., 0::2] = torch.sin(values)
    y[..., 1::2] = torch.cos(values)

    if use_l2_norm:
        y = torch.cat([y, torch.ones_like(y[..., :reserved_dim])], dim=-1)
        y /= y.norm(dim=-1, keepdim=True, p=2)

    if norm_const is not None:
        y *= norm_const

    return y

def binary_encode(
    x: Tensor,
    embedding_dim: int,
    min_value: int | float,
    max_value: int | float,
    use_l2_norm: bool = False,
    norm_const: float | None = None,
) -> Tensor:
    y = torch.zeros(x.shape + (embedding_dim,), device=x.device)
    reserve_dim = 0 if not use_l2_norm else 1
    x = x - min_value
    maximum = x.max()
    for i in range(embedding_dim - reserve_dim):
        coeff = 2**i
        if maximum < coeff:
            break
        y[..., -i - 1] = torch.floor(x / coeff) % 2
        x = x - coeff * y[..., -i - 1]
    if use_l2_norm:
        y = torch.cat([y, torch.ones_like(y[..., :reserve_dim])], dim=-1)
        y /= y.norm(dim=-1, keepdim=True, p=2)
    if norm_const is not None:
        y *= norm_const
    return y

### Prepare model and data

In [4]:
model_ckpt = "meta-llama/Llama-3.2-1B"
model = transformers.AutoModel.from_pretrained(model_ckpt).eval()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_ckpt)
model = model.half().to(device).eval()

In [5]:
all_values = torch.arange(0, 1000)
mask = torch.rand(len(all_values), generator=torch.Generator().manual_seed(0))
train_mask = mask < 0.9
valid_mask = ~train_mask & (mask < 0.95)
test_mask = ~train_mask & ~valid_mask

train_values = all_values[train_mask]
valid_values = all_values[valid_mask]
test_values = all_values[test_mask]

In [6]:
all_inputs = [(x1, x2) for x1, x2 in itertools.product(all_values.tolist(), repeat=2) if x1 + x2 < 1000]
train_values_set = set(train_values.tolist())
valid_values_set = set(valid_values.tolist())
test_values_set = set(test_values.tolist())
        
train_inputs = [(x1, x2) for x1, x2 in all_inputs if x2 in train_values_set]
valid_inputs = [(x1, x2) for x1, x2 in all_inputs if x2 in valid_values_set]
test_inputs = [(x1, x2) for x1, x2 in all_inputs if x2 in test_values_set]

# sanity check
assert set(train_inputs) & set(valid_inputs) == set()
assert set(train_inputs) & set(test_inputs) == set()
assert set(valid_inputs) & set(test_inputs) == set()

random.seed(0)
random.shuffle(train_inputs)
random.shuffle(valid_inputs)
random.shuffle(test_inputs)
valid_size = 4096
train_size = 100_000
train_inputs = train_inputs[:train_size]
valid_inputs = valid_inputs[:valid_size]

In [7]:
def make_str_input(operands: tuple[int, int] | list[int]) -> str:
    x1, x2 = operands
    return f"{x1} + {x2}"

make_str_input((3, 500)), make_str_input((3, 0))

('3 + 500', '3 + 0')

In [8]:
def get_hidden_states(model, str_inputs: list[str], batch_size: int) -> collections.defaultdict[int, Tensor]:
    model.eval()
    hidden_states = collections.defaultdict(list)
    with torch.no_grad():
        num_batches = (len(str_inputs) + batch_size - 1) // batch_size
        for batch_str in tqdm.auto.tqdm(itertools.batched(str_inputs, n=batch_size), total=num_batches):
            batch_inputs = tokenizer(batch_str, return_tensors="pt")
            hidden_reprs = model(**batch_inputs.to(model.device), output_hidden_states=True).hidden_states
            for layer_idx, hidden_state in enumerate(hidden_reprs):
                hidden_states[layer_idx].extend(hidden_state[:, -1, :].detach().cpu())
    return {k: torch.stack(v) for k, v in hidden_states.items()}

In [9]:
batch_size = 1024
train_hidden_states = get_hidden_states(model, [make_str_input(val) for val in train_inputs], batch_size)
valid_hidden_states = get_hidden_states(model, [make_str_input(val) for val in valid_inputs], batch_size)
test_hidden_states = get_hidden_states(model, [make_str_input(val) for val in test_inputs], batch_size)

  0%|          | 0/98 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

### Probing

In [10]:
basis_embs_sin = sinusoidal_encode(
    torch.arange(1000),
    min_value=0,
    max_value=1000,
    embedding_dim=train_hidden_states[0].shape[-1],
)

basis_embs_bin = binary_encode(
    torch.arange(1000),
    min_value=0,
    max_value=1000,
    embedding_dim=10,
)

In [11]:
class ClassifierProbe(torch.nn.Module):
    def __init__(self, emb_dim: int, hidden_dim: int, basis: torch.Tensor, heldout_mask: torch.Tensor):
        super().__init__()
        self.emb_to_latent = torch.nn.Linear(emb_dim, hidden_dim, bias=True)
        self.basis_to_latent = torch.nn.Linear(basis.shape[-1], hidden_dim, bias=True)
        self.basis: torch.nn.Buffer
        self.heldout_mask: torch.nn.Buffer
        self.register_buffer("basis", basis)
        self.register_buffer("heldout_mask", heldout_mask)
    def forward(self, x: Tensor, holdout_eval_tokens: bool) -> Tensor:
        latent_x = self.emb_to_latent(x)
        # during training, model learns to choose among only training tokens
        # but during eval, model must choose among all tokens
        # this means that the model is never exposed to the eval tokens during training
        latent_choices = self.basis_to_latent(self.basis)
        logits = latent_x @ latent_choices.T
        if holdout_eval_tokens:
            logits[:, self.heldout_mask] = float("-inf")
        return logits

In [12]:
train_labels = torch.tensor([x2 for x1, x2 in train_inputs])
valid_labels = torch.tensor([x2 for x1, x2 in valid_inputs]).to(device)
test_labels = torch.tensor([x2 for x1, x2 in test_inputs]).to(device) 

test_accuracies = {"sin": {}, "bin": {}, "lin": {}, "log": {}}

for basis_name, basis_embs in {"sin": basis_embs_sin, "bin": basis_embs_bin}.items():
    for layer_idx in range(len(train_hidden_states)):

        torch.manual_seed(0)
        probe = ClassifierProbe(
            emb_dim=train_hidden_states[0].shape[-1],
            hidden_dim=100,
            basis=basis_embs,
            heldout_mask=test_mask,
        ).to(device)

        optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)

        rng = torch.Generator().manual_seed(0)
        best_val_acc = -1
        best_ckpt = None
        for i in range(10000+1):
            probe.train()
            optimizer.zero_grad()
            minibatch_idcs = torch.randint(len(train_labels), size=(1024,), generator=rng)
            x = train_hidden_states[layer_idx][minibatch_idcs].float().to(device)
            y = train_labels[minibatch_idcs].to(device)
            logits = probe(x, holdout_eval_tokens=True)
            # add l1 regularization of all params to the loss
            loss = torch.nn.functional.cross_entropy(logits, y) + 0.001 * sum(p.abs().sum() for p in probe.parameters())
            loss.backward()
            optimizer.step()
            if i % 500 == 0:
                train_acc = (logits.argmax(dim=-1) == y).float().mean().item()
                probe.eval()
                with torch.no_grad():
                    valid_logits = probe(valid_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)
                    valid_loss = torch.nn.functional.cross_entropy(valid_logits, valid_labels)
                    valid_accuracy = (valid_logits.argmax(dim=-1) == valid_labels).float().mean().item()
                    if valid_accuracy > best_val_acc:
                        best_val_acc = valid_accuracy
                        best_ckpt = probe.state_dict()
                print(f"{basis_name} {i=:>5} train loss: {loss.item():5.2f}  train acc: {train_acc:.2f}  val loss: {valid_loss.item():5.2f}  valid acc: {valid_accuracy:.2f}")
        probe.load_state_dict(best_ckpt)
        probe.eval()
        with torch.no_grad():
            test_logits = probe(test_hidden_states[layer_idx].float().to(device), holdout_eval_tokens=False)
            test_accuracy = (test_logits.argmax(dim=-1) == test_labels).float().mean().item()
        test_accuracies[basis_name][layer_idx] = test_accuracy
        print(f"->  {basis_name}  layer idx: {layer_idx:<3}, best valid accuracy: {best_val_acc:.2f}, test accuracy: {test_accuracy:.2f}")
                    

sin i=    0 train loss: 11.39  train acc: 0.00  val loss:  6.72  valid acc: 0.00
sin i=  500 train loss:  2.23  train acc: 0.99  val loss:  1.36  valid acc: 0.83
sin i= 1000 train loss:  1.83  train acc: 0.99  val loss:  0.94  valid acc: 0.92
sin i= 1500 train loss:  1.66  train acc: 1.00  val loss:  0.78  valid acc: 0.88
sin i= 2000 train loss:  1.56  train acc: 0.99  val loss:  0.74  valid acc: 0.89
sin i= 2500 train loss:  1.50  train acc: 0.99  val loss:  0.73  valid acc: 0.84
sin i= 3000 train loss:  1.41  train acc: 1.00  val loss:  0.71  valid acc: 0.84
sin i= 3500 train loss:  1.35  train acc: 1.00  val loss:  0.71  valid acc: 0.82
sin i= 4000 train loss:  1.32  train acc: 1.00  val loss:  0.71  valid acc: 0.82
sin i= 4500 train loss:  1.28  train acc: 1.00  val loss:  0.72  valid acc: 0.80
sin i= 5000 train loss:  1.26  train acc: 1.00  val loss:  0.68  valid acc: 0.80
sin i= 5500 train loss:  1.22  train acc: 1.00  val loss:  0.69  valid acc: 0.82
sin i= 6000 train loss:  1.1

In [13]:
def solve_linear_layer(x: Tensor, y: Tensor) -> torch.nn.Linear:
    if y.ndim == 1:
        y = y.unsqueeze(-1)
    if not y.is_floating_point():
        y = y.float()
   
    lin = torch.nn.Linear(x.shape[-1], y.shape[-1], device=x.device)
    x_aug = torch.cat([x, torch.ones(len(x), 1, device=x.device)], dim=1)
    coeffs = torch.linalg.lstsq(x_aug, y).solution
    w, b = coeffs[:-1], coeffs[-1]
    with torch.no_grad():
        lin.weight[:] = w.T
        lin.bias[:] = b
    return lin

In [24]:
for layer_idx in range(len(train_hidden_states)):
    lin_probe = solve_linear_layer(
        train_hidden_states[layer_idx].float().to(device),
        train_labels.to(device),
    )
    log_probe = solve_linear_layer(
        train_hidden_states[layer_idx].float().to(device),
        train_labels.log1p().to(device),
    )
    lin_test_pred = lin_probe(test_hidden_states[layer_idx].float().to(device)).flatten().round().int()
    lin_test_accuracy = (lin_test_pred == test_labels).float().mean().item()
    
    log_test_pred = log_probe(test_hidden_states[layer_idx].float().to(device)).flatten().exp().add(1).round().int()
    log_test_accuracy = (log_test_pred == test_labels).float().mean().item()
    
    test_accuracies["lin"][layer_idx] = lin_test_accuracy
    test_accuracies["log"][layer_idx] = log_test_accuracy

    print(f"layer idx: {layer_idx:<3}, linear probe acc: {lin_test_accuracy:.2f}, log probe acc: {log_test_accuracy:.2f}")

layer idx: 0  , linear probe acc: 0.02, log probe acc: 0.00
layer idx: 1  , linear probe acc: 0.01, log probe acc: 0.00
layer idx: 2  , linear probe acc: 0.02, log probe acc: 0.02
layer idx: 3  , linear probe acc: 0.03, log probe acc: 0.03
layer idx: 4  , linear probe acc: 0.02, log probe acc: 0.03
layer idx: 5  , linear probe acc: 0.02, log probe acc: 0.03
layer idx: 6  , linear probe acc: 0.02, log probe acc: 0.04
layer idx: 7  , linear probe acc: 0.02, log probe acc: 0.04
layer idx: 8  , linear probe acc: 0.02, log probe acc: 0.05
layer idx: 9  , linear probe acc: 0.02, log probe acc: 0.04
layer idx: 10 , linear probe acc: 0.03, log probe acc: 0.05
layer idx: 11 , linear probe acc: 0.02, log probe acc: 0.03
layer idx: 12 , linear probe acc: 0.02, log probe acc: 0.03
layer idx: 13 , linear probe acc: 0.02, log probe acc: 0.03
layer idx: 14 , linear probe acc: 0.03, log probe acc: 0.05
layer idx: 15 , linear probe acc: 0.02, log probe acc: 0.04
layer idx: 16 , linear probe acc: 0.02, 

In [25]:
for name, accs in test_accuracies.items():
    print(f"{name} accs: | " + " | ".join([f"{x:.0%}" for x in accs.values()]) + " |")

sin accs: | 89% | 96% | 100% | 99% | 100% | 100% | 100% | 99% | 98% | 99% | 100% | 99% | 99% | 99% | 96% | 99% | 86% |
bin accs: | 41% | 15% | 12% | 23% | 16% | 12% | 8% | 9% | 9% | 9% | 12% | 11% | 9% | 7% | 11% | 14% | 11% |
lin accs: | 2% | 1% | 2% | 3% | 2% | 2% | 2% | 2% | 2% | 2% | 3% | 2% | 2% | 2% | 3% | 2% | 2% |
log accs: | 0% | 0% | 2% | 3% | 3% | 3% | 4% | 4% | 5% | 4% | 5% | 3% | 3% | 3% | 5% | 4% | 4% |
