About
-------

These are some experiments with so-called "glitch tokens" in transformer language models. The code below is adapted from [this blog post](https://www.lesswrong.com/posts/aPeJE8bSo6rAFoLqg/solidgoldmagikarp-plus-prompt-generation).

In [1]:
import json

import torch
from transformers import (
    AutoTokenizer, GPT2Tokenizer, GPTJForCausalLM, GPT2LMHeadModel
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The glitch tokens include non-ASCII and control characters; JSON is UTF-8,
# so decode explicitly instead of relying on the platform's default codec.
with open("glitchtokens.json", 'r', encoding='utf-8') as fin:
    glitch = json.load(fin)

In [3]:
name = "gpt-j"

# One branch per model family: GPT-J loads its float16 revision from the hub,
# anything else is loaded as a GPT-2-style checkpoint named by `name`.
# Within each branch the tokenizer is loaded before the model.
if 'gpt-j' in name:
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    model = GPTJForCausalLM.from_pretrained(
        "EleutherAI/gpt-j-6B",
        revision="float16",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
else:
    tokenizer = GPT2Tokenizer.from_pretrained(name, padding_side='left')
    model = GPT2LMHeadModel.from_pretrained(
        name,
        pad_token_id=tokenizer.eos_token_id,
    )

In [4]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
model.eval()
model.to(device)

# Token-embedding matrix (wte), one row per token id, with each row scaled to
# unit L2 norm so that a dot product between rows is their cosine similarity.
embeddings = model.transformer.wte.weight.detach()
row_norms = embeddings.pow(2).sum(dim=-1, keepdim=True).sqrt()
embeddings = embeddings / row_norms

In [6]:
def get_dists(emb, word_embeddings):
    """Rank every row of ``word_embeddings`` by distance from ``emb``.

    Distance is ``1 - dot(emb, row)``, which is the cosine distance when both
    ``emb`` and the rows are unit-length (as the notebook's normalised
    ``embeddings`` are).

    Args:
        emb: query vector (assumed 1-D of the embedding width — TODO confirm
            for batched use).
        word_embeddings: matrix with one embedding per row.

    Returns:
        ``(sorted_dists, ix)``: distances in ascending order and the row
        indices that produce that order.
    """
    # NOTE(review): flushing the CUDA cache on every call is costly inside the
    # 25k-sample loop; kept here to preserve the original behaviour.
    torch.cuda.empty_cache()
    similarity = (emb.unsqueeze(0) @ word_embeddings.T).squeeze(0)
    ranked_dists, order = torch.sort(1 - similarity)
    return ranked_dists, order

def closest_tokens(emb, word_embeddings, tokenizer, n=1):
    """Return the ``n`` vocabulary entries nearest to ``emb``, nearest first.

    Args:
        emb: query embedding, ranked against ``word_embeddings`` by ``get_dists``.
        word_embeddings: embedding matrix, one row per token id.
        tokenizer: object whose ``decode`` turns a token id into text.
        n: how many neighbours to return.

    Returns:
        ``(tokens, ixs, dists, embs)``: decoded strings, token ids, distances
        and the corresponding embedding rows.
    """
    ranked_dists, order = get_dists(emb, word_embeddings)

    nearest_ix = order[:n]
    nearest_dists = ranked_dists[:n]
    decoded = [tokenizer.decode(token_id) for token_id in nearest_ix]

    return decoded, nearest_ix, nearest_dists, word_embeddings[nearest_ix]

def most_distant_tokens(emb, word_embeddings, tokenizer, n=1):
    """Return the ``n`` vocabulary entries farthest from ``emb``, farthest first.

    Mirrors ``closest_tokens`` but reads the distance ranking produced by
    ``get_dists`` from its far end.

    Args:
        emb: query embedding, ranked against ``word_embeddings`` by ``get_dists``.
        word_embeddings: embedding matrix, one row per token id.
        tokenizer: object whose ``decode`` turns a token id into text.
        n: how many tokens to return; ``n=0`` yields empty results.

    Returns:
        ``(tokens, ixs, dists, embs)`` ordered from most to least distant.
    """
    sorted_dists, ix = get_dists(emb, word_embeddings)

    # Reverse the ascending ranking first, then take the head. The previous
    # ix[-n:] form returned the *entire* vocabulary for n == 0, since -0 == 0.
    ixs = torch.flip(ix, dims=(0,))[:n]
    dists = torch.flip(sorted_dists, dims=(0,))[:n]
    tokens = [tokenizer.decode(i) for i in ixs]
    embs = word_embeddings[ixs]

    return tokens, ixs, dists, embs

def query_by_word(word, embeddings, tokenizer, n=1, func=closest_tokens):
    """Look up the embedding of ``word`` and run a neighbour query on it.

    Args:
        word: text that must encode to exactly one token id.
        embeddings: embedding matrix indexed by token id.
        tokenizer: object whose ``encode`` maps text to a sequence of token
            ids; it is also passed on to ``func``.
        n: number of results ``func`` should return.
        func: ``closest_tokens`` (default), ``most_distant_tokens``, or any
            callable taking ``(emb, word_embeddings, tokenizer, n)``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        ValueError: if ``word`` does not encode to exactly one token.
    """
    token_ids = tokenizer.encode(word)
    # Fail with an explanatory message instead of the bare
    # "too many / not enough values to unpack" from `ix ,= ...`.
    if len(token_ids) != 1:
        raise ValueError(
            f"{word!r} encodes to {len(token_ids)} token(s) {list(token_ids)}; "
            "query_by_word needs text that maps to exactly one token"
        )
    emb = embeddings[token_ids[0]]

    return func(emb, embeddings, tokenizer, n)

In [7]:
# Center of all the (unit-normalised) embeddings, itself rescaled to unit
# length so it can be ranked against the rows of `embeddings` by get_dists.
# The mean over the whole vocabulary is computed once and reused.
mean_embedding = embeddings.mean(dim=0)
centroid = mean_embedding / torch.sqrt(torch.sum(mean_embedding**2, dim=-1, keepdim=True))

In [8]:
# For each glitch token, list the 5 vocabulary entries farthest from it.
# Some glitch tokens (e.g. ids 188-215) are control characters that print as
# blanks, so show the query token via repr() to make it visible.
for tok in glitch['tokens']:
    toks, ixs, dists, _ = query_by_word(
        tok
        , embeddings
        , tokenizer
        , n=5
        , func=most_distant_tokens
    )
    print(f"tok: {tok!r} | idx: {tokenizer.encode(tok)[0]}")
    print(', '.join(toks), "\n")

tok:   | idx: 188
 �, cheat, kids,  prob, utm 

tok:  | idx: 189
 �, kids, gif,  Kindle, cheat 

tok:  | idx: 190
 �, cheat, wikipedia, prototype, caps 

tok:  | idx: 191
 �, gif, �, kids, � 

tok:  | idx: 192
 �, utm, kids, DW, cler 

tok:  | idx: 193
 �, DW, kids, bsp, cheat 

tok:  | idx: 194
 �,  ›, caps, utm, kids 

tok:  | idx: 195
 �,  prob, caps, cheat,  � 

tok:  | idx: 196
 �, kids, cheat, HOU,  LET 

tok:  | idx: 202
 �, caps, kids,  Kavanaugh,  Sard 

tok:  | idx: 203
 �, �, �, �, utm 

tok:  | idx: 204
 �, �, gif, �,  � 

tok:  | idx: 205
 �, �, wikipedia, �, ソ 

tok:  | idx: 206
 �, Â, caps, gif, cheat 

tok:  | idx: 207
 �, caps, gif, ._,  � 

tok:  | idx: 208
 �, utm,  ›, gif, ._ 

tok:  | idx: 209
 �, gif, utm, ._,  › 

tok:  | idx: 210
 �, gif, utm, ��, ._ 

tok:  | idx: 211
 �, ._, gif,  ›, caps 

tok:  | idx: 212
 �, kids, ._, utm, gif 

tok:  | idx: 213
 �, ._, gif,  �, kids 

tok:  | idx: 214
 �, gif, ._, utm,  › 

tok:  | idx: 215
 �, kids, 

In [9]:
# Decoded strings of the 25 tokens nearest the centroid (ids, distances and
# embeddings are discarded).
toks = closest_tokens(centroid, embeddings, tokenizer, n=25)[0]
toks

['�',
 '�',
 '�',
 '�',
 '�',
 '�',
 '�',
 '�',
 ' davidjl',
 'PsyNetMessage',
 'embedreportprint',
 ' RandomRedditor',
 ' RandomRedditorWithNo',
 'InstoreAndOnline',
 ' TheNitrome',
 ' TheNitromeFan',
 'GoldMagikarp',
 ' srfN',
 ' SolidGoldMagikarp',
 ' Adinida',
 ' DevOnline',
 '<|extratoken_1|>',
 '<|extratoken_2|>',
 '<|extratoken_3|>',
 '<|extratoken_4|>']

In [10]:
# Certain glitch tokens show up surprisingly often among the top-n nearest
# neighbours of ordinary tokens. The sampling cells below measure how often.
query_by_word('Book', embeddings, tokenizer, func=closest_tokens, n=10)

(['Book',
  ' Book',
  ' book',
  'book',
  '<|extratoken_30|>',
  '�',
  ' externalToEVAOnly',
  'embedreportprint',
  ' RandomRedditor',
  'GoldMagikarp'],
 tensor([10482,  4897,  1492,  2070, 50286,   184, 30213, 30898, 36173, 42202],
        device='cuda:0'),
 tensor([0.0000, 0.1167, 0.1323, 0.1338, 0.1401, 0.1406, 0.1406, 0.1406, 0.1406,
         0.1406], device='cuda:0', dtype=torch.float16),
 tensor([[ 6.2180e-03,  9.6464e-04, -1.1539e-03,  ..., -1.0269e-02,
           2.0859e-02,  6.6299e-03],
         [ 1.2955e-02, -1.5671e-02, -2.5997e-03,  ...,  4.3869e-03,
           2.1912e-02, -4.0970e-03],
         [ 7.3967e-03, -1.2388e-03, -1.2657e-02,  ...,  1.3132e-03,
           1.0185e-02, -1.3130e-02],
         ...,
         [ 1.0529e-03, -1.3232e-04,  9.7418e-04,  ...,  4.0865e-04,
           1.7452e-03, -3.6299e-05],
         [ 1.8482e-03,  5.7888e-04,  1.2407e-03,  ...,  4.5824e-04,
           1.3008e-03, -4.5419e-04],
         [ 1.3437e-03,  1.4555e-04,  1.3247e-03,  ...,  3.4

In [11]:
# Token ids to probe, drawn uniformly with replacement from the vocabulary.
# A fixed seed makes the sample, and therefore the statistics below, reproducible.
SAMPLE_SEED = 0
N_SAMPLES = 25_000
samples = torch.randint(
    len(embeddings),
    (N_SAMPLES,),
    generator=torch.Generator().manual_seed(SAMPLE_SEED),
)

In [12]:
TOP_N = 10  # neighbours retrieved per sampled token
# Set membership is O(1); the list lookup was repeated for every retrieved token.
glitch_token_set = set(glitch['tokens'])

mean_dists = []
glitch_counts = []
for samp in samples:
    emb = embeddings[samp.item()]
    toks, ixs, dists, _ = closest_tokens(emb, embeddings, tokenizer, n=TOP_N)
    # NOTE(review): toks[0] is normally the sampled token itself (distance 0);
    # it is counted here but excluded from mean_dists via dists[1:] — confirm
    # which convention is intended.
    glitch_counts.append(sum(tok in glitch_token_set for tok in toks) / TOP_N)
    mean_dists.append(dists[1:].mean().cpu())

In [13]:
# Average, over all sampled tokens, of the mean distance to their retrieved
# neighbours (the first neighbour, normally the token itself, is excluded).
per_sample_mean_dists = torch.tensor(mean_dists)
per_sample_mean_dists.mean().item()

0.134521484375

In [14]:
# Fraction of the sampled top-10 neighbour lists made up of glitch tokens.
# As one comment on the cited blog post suggests, this may reflect "hubness":
# in high-dimensional feature spaces some points appear among the nearest
# neighbours of disproportionately many other points.
#
# Comment: https://www.lesswrong.com/posts/aPeJE8bSo6rAFoLqg/solidgoldmagikarp-plus-prompt-generation#comments
per_sample_glitch_fraction = torch.tensor(glitch_counts)
per_sample_glitch_fraction.mean().item()

0.35468000173568726