I had the idea of using a pretrained embedding layer, inverted, as the first layer of a decoder, forcing the latent space to use its structure. Ideally this gives a compressed but human-interpretable latent space. If it works, it should help probes. First, let's prototype the embedding.

see also
- https://keras.io/api/keras_nlp/modeling_layers/reversible_embedding/

In [1]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

from typing import Optional, List, Dict, Union
from jaxtyping import Float
from torch import Tensor

import torch
import torch.nn as nn
import torch.nn.functional as F

from pathlib import Path
from einops import rearrange

from loguru import logger
import lightning.pytorch as pl

# Send log records at INFO level and above to stderr using a compact
# "time level message" format.
# NOTE(review): `os.sys` only resolves because the `os` module itself imports
# `sys`; a plain `import sys` would be clearer. Also confirm whether loguru's
# default handler is still active here, in which case messages may appear twice.
logger.add(os.sys.stderr, format="{time} {level} {message}", level="INFO")

# load my code
%load_ext autoreload
%autoreload 2

from src.llms.load import load_model

import seaborn as sns
# Plot styling for the notebook: seaborn theme with the "paper" setting.
sns.set_theme("paper")

# Small default figure size (width, height) so inline plots stay compact.
plt.rcParams["figure.figsize"] = (4, 3)

In [2]:
# Load the language model and its tokenizer on CPU through the project helper.
# NOTE(review): `bnb` presumably toggles bitsandbytes quantisation — confirm
# against src.llms.load.
model, tokenizer = load_model(
    # cfg.model,
    device='cpu',
    bnb=False,
    trust_remote_code=True,
    # model_class=PhiForCausalLMWHS, # to add hidden states
    # bnb=False,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Take the token-embedding layer out of the loaded model, switch it to eval
# mode and keep it on CPU; the bare name on the last line displays its repr.
embedding = model.model.embed_tokens.eval().cpu()
embedding

Embedding(51200, 2560)

In [4]:
# Rows of the embedding matrix give the vocabulary size, columns the hidden width.
vocab_size, hidden_dim = embedding.weight.shape

In [5]:
batch_size = 2
# vocab_size = 100
# hidden_dim = 32
seq_length = 50

# Generate random inputs.
# Draw the token ids from a seeded, local generator so the cell yields the same
# ids on every Restart & Run All without touching NumPy's global RNG state.
rng = np.random.default_rng(0)
token_ids = torch.tensor(rng.integers(vocab_size, size=(batch_size, seq_length)))

# Keras reference that this prototype mirrors:
# embedding = keras_nlp.layers.ReversibleEmbedding(vocab_size, hidden_dim)

# Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.
hidden_states = embedding(token_ids)

# Project hidden states to shape `(batch_size, seq_length, vocab_size)`.
# logits = embedding(hidden_states, reverse=True)

# compare
# token_ids == logits.argmax(-1)

In [6]:
# F.embedding(token_ids, embedding.weight).shape
# F.embedding??

In [7]:
# torch.embedding(embedding.weight, token_ids).shape

In [45]:
class ReversibleEmbedding(nn.Module):
    """Embedding layer whose single weight matrix serves both directions.

    The forward direction maps vocabulary space to hidden space. The reverse
    direction maps hidden space back to vocabulary scores by multiplying with
    the transpose of the same matrix. The reverse pass is a transpose, not a
    true inverse, so the returned scores are unnormalised.

    Args:
        vocab_size: number of rows of the weight matrix (vocabulary entries).
        hidden_dim: number of columns of the weight matrix (embedding width).
    """

    def __init__(self, vocab_size, hidden_dim):
        super().__init__()
        # Randomly initialised (vocab_size, hidden_dim) matrix shared by both passes.
        self.weight = nn.Parameter(torch.randn(vocab_size, hidden_dim))
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim

    def forward(self, x, reverse=False):
        """Embed into hidden space, or project hidden states back to vocab scores.

        Args:
            x: with ``reverse=False``, either a float tensor whose last
                dimension is ``vocab_size`` (one-hot or soft token
                distribution) or an integer tensor of token ids. With
                ``reverse=True``, a float tensor whose last dimension is
                ``hidden_dim``.
            reverse: when True, apply the transposed projection instead of
                the embedding.

        Returns:
            ``(..., hidden_dim)`` when ``reverse`` is False, otherwise
            ``(..., vocab_size)`` unnormalised scores; taking ``argmax(-1)``
            is left to the caller.
        """
        W = self.weight
        # TODO: norm and unnorm
        if reverse:
            # Alternatives that were tried for the reverse pass:
            # return 2 * torch.matmul(x, W.T)
            # return torch.sum(x ** 2, dim=-1, keepdim=True) + torch.sum(W**2, dim=-1) - 2 *  x @ W.T
            return x @ W.T

        if x.dtype in (torch.int32, torch.int64):
            # Integer token ids: a row lookup gives the same result as
            # one_hot(x) @ W without materialising the one-hot tensor.
            return F.embedding(x.long(), W)

        # Float input is treated as an (optionally soft) distribution over the
        # vocabulary, i.e. the one-hot encoding step is skipped by the caller.
        return x @ W
        
        
        
# Instantiate with the pretrained layer's dimensions; the weights themselves are random.
m = ReversibleEmbedding(vocab_size, hidden_dim)


In [46]:
# torch.linalg.inv(m.weight[None]) 
# F.normalize(m.weight, dim=-1)
 

In [47]:
# Round-trip sanity check: one-hot tokens -> hidden space -> vocabulary scores.
x = F.one_hot(token_ids, num_classes=vocab_size).float().detach()
y = m(x)
x2 = m(y, reverse=True).detach()
# x2 = x2 / x2.sum(-1, keepdim=True)
token_ids2 = torch.argmax(x2, dim=-1)
# Fraction of entries where the rounded scores equal the one-hot input.
print((x == x2.round()).float().mean())
# Fraction of positions whose argmax recovers the original token id.
print((token_ids2 == token_ids).float().mean())
x.shape, y.shape, x2.shape

tensor(0.0079)
tensor(1.)


(torch.Size([2, 50, 51200]),
 torch.Size([2, 50, 2560]),
 torch.Size([2, 50, 51200]))

In [52]:
# Push a random non-negative, vocabulary-sized vector through forward and reverse.
z = torch.randn_like(x).abs().detach()
z2 = m(z).detach()
z3 = m(z2, reverse=True).detach()
print(*(t.shape for t in (z, z2, z3)))

# z3 = z3.abs()
# z3 = z3 / z3.abs().sum(-1, keepdim=True)
# z-z3

torch.Size([2, 50, 51200]) torch.Size([2, 50, 2560]) torch.Size([2, 50, 51200])


In [53]:
# Inspect the first batch element of the random input vector.
z[0]

tensor([[1.1103, 0.2669, 0.5260,  ..., 1.0474, 1.8734, 0.9161],
        [0.0934, 0.3709, 0.6267,  ..., 0.1941, 1.0740, 0.4241],
        [1.9605, 0.2693, 0.8677,  ..., 0.2919, 3.0014, 0.4272],
        ...,
        [1.0976, 0.6049, 1.3475,  ..., 0.0136, 0.5689, 0.4565],
        [0.2915, 1.4355, 0.0290,  ..., 1.3238, 0.2580, 0.9787],
        [0.2953, 1.1903, 0.8469,  ..., 0.4034, 0.6518, 0.6749]])

In [54]:
# Inspect the first batch element after the forward + reverse round trip.
z3[0]

tensor([[ 16546.8418,  14871.1045,  15367.9600,  ...,  14620.1445,
          -2121.3865,  22985.0586],
        [ 27663.6582,  -9911.5547,   7147.8276,  ...,   9748.4248,
           7331.7178,   6365.5449],
        [ 17518.9160,  15047.8984,  -5408.0283,  ...,  -7657.2080,
          23668.7109,  10854.2041],
        ...,
        [ 22566.5449,   8952.7881,   4303.7622,  ...,  -9909.8447,
          -3576.2451,  19538.4238],
        [ 16341.0459,   5855.2935,   4256.1382,  ...,  17216.6562,
         -16308.8965,  11373.1865],
        [ 17308.6270,  23979.8477,  20052.1953,  ...,   6870.7231,
           4182.9902,  10143.3984]])

: 

In [51]:
# Expected to fail: the reverse pass multiplies by W.T, so z3 = z @ W @ W.T,
# which is not z in general (see the recorded AssertionError). A true inverse
# is still a TODO.
np.testing.assert_allclose(z, z3)

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 5120000 / 5120000 (100%)
Max absolute difference: 60478.492
Max relative difference: 139.55426
 x: array([[[-0.916023, -1.034765,  0.571842, ...,  1.658556, -0.036739,
         -0.339379],
        [-0.379376,  0.058935, -0.989116, ...,  0.105902, -0.734908,...
 y: array([[[-19673.627  ,    597.6177 , -11513.899  , ..., -13899.47   ,
            691.8318 ,  -4066.745  ],
        [-27098.701  ,  -9599.435  ,  -6322.4517 , ...,  -5654.139  ,...

In [None]:
# TODO invert

In [None]:
# TODO interpret a random vector, when using it inverted...

In [None]:
# Reference snippet: squared Euclidean distance from every row of z_flattened to
# every embedding row, expanded as |z|^2 + |e|^2 - 2 z.e^T.
# NOTE(review): `z_flattened` and `self` are not defined anywhere in the visible
# notebook, so this cell cannot run as-is — presumably pasted from a module's
# method; confirm before executing.
dist_to_embeddings = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t())