In [51]:
# Some problems with this approach:
# - Simpler vocab !=? simpler grammar
# - May want simpler words to have a higher probability, 
#     rather than just removing all words.
# Idea:
# what if we blend the output probabilities
# with the probabilities that arise in a corpus.
# e.g. the probabilities that jk rowling uses, etc.
# At that point might as well finetune, 
# but would be interesting to see what happens/compare/benchmark.
# Idea:
# Might be fun to do a grid search/random search, over .generate() parameters,
# and compare the average perplexity (averaged over N generation attempts per parameter set)

In [52]:
import json
import pathlib
import re

import numpy as np
import spacy
import torch
from tqdm import tqdm
from transformers import GPT2Tokenizer, GPTNeoForCausalLM


In [53]:
if not globals().get("model"):
    SMALL_MODEL = False  # use smaller model for testing
    MODEL = "EleutherAI/gpt-neo-125M" if SMALL_MODEL else "EleutherAI/gpt-neo-1.3B"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPTNeoForCausalLM.from_pretrained(MODEL).half().to(device)
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL)
    # Keep a reference to the original lm_head
    original_lm_head = model.lm_head
print(f"{device = }")
print(f"{MODEL = }")

device = device(type='cuda')
MODEL = 'EleutherAI/gpt-neo-1.3B'


In [54]:
class MaskedLMHead(torch.nn.Module):
    """This class is a wrapper around the language model head of the gpt neo model.
    It is used to mask the logits of the language model head,
    to restrict the vocabulary of the model.
    """

    def __init__(self, lm_head, mask):
        super().__init__()
        self.lm_head = lm_head
        self.mask = mask

    def forward(self, *args, **kwargs):
        logits = self.lm_head(*args, **kwargs)
        masked_logits = logits.masked_fill(self.mask, float("-inf"))
        return masked_logits


# Create a mask that makes all tokens except "ok"
# have a predicted probability of zero
tokenizer_dict = tokenizer.get_vocab()
mask = torch.ones(50257, device=device, dtype=torch.bool)
idx_of_ok = tokenizer_dict["ok"]
mask[idx_of_ok] = 0  # 0 means don't mask

# Replace the model's lm_head with our masked version
print(f"Original head: {original_lm_head = }")
model.lm_head = MaskedLMHead(original_lm_head, mask)
print(f"Masked head: {model.lm_head = }")

Original head: original_lm_head = Linear(in_features=2048, out_features=50257, bias=False)
Masked head: model.lm_head = MaskedLMHead(
  (lm_head): Linear(in_features=2048, out_features=50257, bias=False)
)


In [55]:
# Run the model on a dummy input and
# confirm that all tokens except for "ok"
# get a probability of 0
prompt = "There are five pastries on the"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
attention_mask = torch.zeros_like(input_ids)
logits = model(input_ids, attention_mask=attention_mask).logits
probabilities = torch.softmax(logits, dim=-1)
print(f"{probabilities = }")
print(f"{probabilities[..., idx_of_ok] = }")

probabilities = tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       dtype=torch.float16, grad_fn=<SoftmaxBackward0>)
probabilities[..., idx_of_ok] = tensor([[1., 1., 1., 1., 1., 1., 1.]], device='cuda:0', dtype=torch.float16,
       grad_fn=<SelectBackward0>)


In [56]:
# See what happens when we try to generate text,
# using our model with the masked language model head
prompt = "Everything is going to be "
tokens = tokenizer(prompt, return_tensors="pt").to(device)
gen_tokens = model.generate(
    **tokens,
    do_sample=True,
    temperature=0.9,
    max_new_tokens=10,
    pad_token_id=tokenizer.eos_token_id,
)
gen_text = tokenizer.batch_decode(gen_tokens)[0]
print(gen_text)

Everything is going to be okokokokokokokokokok


In [57]:
# Create a constrained vocabulary,
# by using only the words (/tokens) found in our corpus dir.
# (We'll load it from a file if it exists, otherwise we'll create it.)

CORPUS_DIR = "vocab-corpus" #"/media/sid/bigdata/datasets/harry-potter-text"
VOCAB_PATH = "vocabulary.json"

nlp = spacy.load("en_core_web_sm")


def create_vocabulary(corpus_dir):
    corpus_files = pathlib.Path(CORPUS_DIR).glob("*.txt")
    words = set()
    for i, corpus_file in enumerate(corpus_files):
        lines = re.sub(r"\n+", "\n", corpus_file.read_text()).split("\n")
        # Process the text 200 lines at a time,
        # to avoid memory issues.
        for i in tqdm(range(0, len(lines), 200), desc=f"File {i}"):
            text = "\n".join(lines[i : i + 200])
            words.update([token.text for token in nlp(text)])
    words = set(word.strip(" ") for word in words)
    # Make sure each word has a the capitalisation: word, Word
    for word in words.copy():
        words.update([word.lower(), word.capitalize()])
    # Include each word with spaces around it
    for word in words.copy():
        words.update([f" {word} ", f" {word}", f"{word} "])
    words.update([" "])
    return sorted(words)


if not pathlib.Path(VOCAB_PATH).exists():
    print("Creating vocabulary...")
    vocabulary = create_vocabulary(CORPUS_DIR)
    with open(VOCAB_PATH, "w") as f:
        json.dump(vocabulary, f, indent=2)
else:
    with open(VOCAB_PATH) as f:
        vocabulary = json.load(f)
    print(f"Loaded vocabulary from {VOCAB_PATH}.")


# Tokenize this vocabulary with our tokenizer
allowed_token_indices = set()
for word in tqdm(vocabulary):
    token_indices = tokenizer.encode(word, add_special_tokens=False)
    allowed_token_indices.update(token_indices)

Loaded vocabulary from vocabulary.json.


100%|██████████| 429/429 [00:00<00:00, 13271.74it/s]


In [58]:
# Create a mask that makes all tokens except
# the ones in our vocabulary have a predicted probability of zero
mask = torch.ones(50257, device=device, dtype=torch.bool)
mask[list(allowed_token_indices)] = 0  # 0 means don't mask
# Use this in our MaskedLMHead for the language model
model.lm_head = MaskedLMHead(original_lm_head, mask)

prompt = """
Little Jane ran up the lane
To hang her clothes a-drying;
She called for Nell to ring the bell,
For Jack and Jill were dying.
Nimble Dick ran up so quick,
He tumbled over a timber,
And bent his bow to shoot a crow,
And killed a cat in the window. 
""".strip()
input_encodings = tokenizer(prompt, return_tensors="pt").to(device)
    
# See what happens when we try to generate text,
# using our model with the masked language model head
generated_texts = []
for i in tqdm(range(40)):
    torch.manual_seed(i)
    np.random.seed(i)
    gen_tokens = model.generate(
        **input_encodings,
        max_new_tokens=100,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        # temperature=0.9,
        # num_beams=5,
        # typical_p=0.7,
        repetition_penalty=10.0,
    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]
    generated_texts.append(gen_text)
    # seperate the prompt from the generated text

100%|██████████| 40/40 [01:21<00:00,  2.03s/it]


In [59]:
# Compute perplexity, to estimate the quality of the results.
# Restore the models original lm_head for this
# references: 
# https://stackoverflow.com/a/61990477
# https://huggingface.co/docs/transformers/perplexity
model.lm_head = original_lm_head
start_loc = input_encodings.input_ids.shape[-1]
perplexities = []
for generated_text in generated_texts:
    encodings = tokenizer(generated_text, return_tensors="pt").to(device)
    perplexity_values = []
    for i in tqdm(range(start_loc, encodings.input_ids.shape[-1])):
        input_ids = encodings.input_ids[:, :i]
        label = encodings.input_ids[:, i]
        with torch.no_grad():
            # next token prediction
            outputs = model(input_ids)
            logits = outputs.logits
            p_next = logits[0, -1, :].softmax(0)[label]
            # perplexity
            perplexity = torch.log(p_next)
            perplexity_values.append(perplexity.item())
    perplexities.append(np.exp(-np.mean(perplexity_values)))
print(f"{perplexities = }")

100%|██████████| 100/100 [00:02<00:00, 36.85it/s]
100%|██████████| 97/97 [00:02<00:00, 36.31it/s]
100%|██████████| 100/100 [00:02<00:00, 35.97it/s]
100%|██████████| 99/99 [00:02<00:00, 36.43it/s]
100%|██████████| 97/97 [00:02<00:00, 36.53it/s]
100%|██████████| 96/96 [00:02<00:00, 36.55it/s]
100%|██████████| 100/100 [00:02<00:00, 36.38it/s]
100%|██████████| 98/98 [00:02<00:00, 36.61it/s]
100%|██████████| 98/98 [00:02<00:00, 36.46it/s]
100%|██████████| 98/98 [00:02<00:00, 36.09it/s]
100%|██████████| 100/100 [00:02<00:00, 36.61it/s]
100%|██████████| 98/98 [00:02<00:00, 36.49it/s]
100%|██████████| 100/100 [00:02<00:00, 36.47it/s]
100%|██████████| 100/100 [00:02<00:00, 36.22it/s]
100%|██████████| 99/99 [00:02<00:00, 36.49it/s]
100%|██████████| 100/100 [00:02<00:00, 36.89it/s]
100%|██████████| 99/99 [00:02<00:00, 36.54it/s]
100%|██████████| 99/99 [00:02<00:00, 37.03it/s]
100%|██████████| 99/99 [00:02<00:00, 36.53it/s]
100%|██████████| 98/98 [00:02<00:00, 36.23it/s]
100%|██████████| 100/100 [

perplexities = [22.796335038836506, 64.54056506488412, 1.2567036303927295, 249.41861558356214, 59.85650366122411, 142.55414367216824, 28.50182903034601, 148.45406628006202, 87.27959976708215, 40.577639982269325, 50.214647283547016, 154.0925203251954, 2.0459603939989828, 78.00486840841882, 42.76320028940462, 54.6839393571641, 44.50173344079129, 20.49421244108672, 69.71377051421017, 70.67250233269793, 2.49397815788712, 57.81493442348986, 85.8760179147562, 68.99754698737738, 85.80091209609346, 106.38654954769524, 49.960863951113915, 44.962026701591256, 70.26894960386151, 79.13663857308812, 141.22372212342958, 15.678126094907935, 24.71580520660862, 218.66673480676752, 117.55381838115538, 136.6527587061457, 3.6199115251073994, 79.89495688750674, 41.882079596977306, 60.22100597449658]





In [60]:
# dump results to a text file
with open("results.txt", "w") as f:
    f.write("""

--------------------------------------------------

""".join(f"{perplexity}\n{result}" for perplexity, result in sorted(zip(perplexities, generated_texts))))

# Hmm, seems like perplexity isn't that useful, 
# the results with low perplexity seem to often have
# a lot of repetition, and the results with high perplexity
# seem to be more unique, but can also lack coherence.

In [61]:
# for debugging
for id in gen_tokens[0].tolist():
    print(tokenizer.decode(id), end="|")

Little| Jane| ran| up| the| lane|
|To| hang| her| clothes| a|-|d|rying|;|
|She| called| for| N|ell| to| ring| the| bell|,|
|For| Jack| and| Jill| were| dying|.|
|N|imble| Dick| ran| up| so| quick|,|
|He| t|umbled| over| a| timber|,|
|And| bent| his| bow| to| shoot| a| crow|,|
|And| killed| a| cat| in| the| window|.|�|�|
|
|The| p|he|as|ent| of| j|in|p|ke|w| came| with| a| g|ail|er| after| the| g|ill|f|rot|,| as|
|the| T|ame| could| see| he| did| she|w| The|s| F|f|her|,| In| the| Her|d| of| Gr|t|as|,|
|M|m|end|,| For| he| went| to| the| p|he|as|ent| of| t|j|m|end|.|
|
|T|t|gr|t|as| got| down| to| t|j|t|and|g|to|of|,| To| see| the| her|d| of|
|J|