In [10]:
import torch
import torch.nn as nn
from datasets import load_from_disk
from tokenizers import Tokenizer

from chat_template import chat_template
from model import ChatModel
from settings import ModelSettings
from special_tokens import special_tokens

tokenizer = Tokenizer.from_file("tokenizer.json")
device = "cpu"

def apply_top_p(logits: torch.Tensor, p: float = 0.9, min_keep: int = 1) -> torch.Tensor:
    """
    In-place top-p filtering on logits
    Very common pattern in vLLM, transformers, llama.cpp python bindings, etc.
    """
    if p >= 1.0:
        return logits

    sorted_logits, sorted_idx = logits.sort(descending=True)
    sorted_probs = sorted_logits.softmax(dim=-1)
    cum_probs = sorted_probs.cumsum(dim=-1)

    # Mask tokens after the top-p cutoff
    mask = cum_probs > p
    # Protect at least min_keep tokens
    mask[..., :min_keep] = False

    # Scatter back
    to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_idx, src=mask
    )

    logits[to_remove] = -float('inf')
    return logits

@torch.no_grad()
def generate(model, start, max_new_tokens=50, temperature=0.7, top_k:int|None=None, argmax=False, stop_tokens=None, top_p=0.9):
    idx = torch.tensor(
        [[tokenizer.token_to_id(special_tokens["bos"])]+ tokenizer.encode(start, add_special_tokens=False).ids],
        device=device,
        dtype=torch.long,
    )
    print(list(map(tokenizer.id_to_token,idx[0].tolist())))

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -ModelSettings.max_context_length:]
        logits = model(idx_cond)
        logits = logits[:, -1, :]

        if top_k is not None:
            top_logits, top_pos = torch.topk(logits, top_k)
            logits = torch.where(
                logits < top_logits[:, -1],
                input=torch.tensor(float("-inf")),
                other=logits
            )
        logits=apply_top_p(logits,top_p)

        probs = nn.functional.softmax(logits / temperature, dim=-1)

        if argmax:
            next_id = torch.argmax(probs, dim=-1, keepdim=True)
        else:
            next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)

        if stop_tokens and next_id.item() in stop_tokens:
            print("Reached stop token")
            break

    return tokenizer.decode(idx[0].tolist())

In [11]:
minified = False

if not minified:
    model = ChatModel(
        vocabulary_size=ModelSettings.vocabulary_size,
        embedding_size=ModelSettings.embedding_size,
        max_context_length=ModelSettings.max_context_length,
        ff_size_multiplier=ModelSettings.ff_size_multiplier,
        transformer_blocks=ModelSettings.transformer_blocks,
        attention_heads=ModelSettings.attention_heads,
        dropout=ModelSettings.dropout,
        bias=ModelSettings.bias,
        device=device,
    )
else:
    model = ChatModel(
        vocabulary_size=ModelSettings.vocabulary_size,
        embedding_size=64,
        max_context_length=64,
        ff_size_multiplier=2,
        transformer_blocks=4,
        attention_heads=4,
        dropout=ModelSettings.dropout,
        bias=ModelSettings.bias,
        device=device,
    )

using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention
using flash attention


In [12]:
step = 5999
state = torch.load(f"instruction_output_colab/{step:05d}_with_dropout.pt", map_location=torch.device('cpu'))

In [13]:
state_dict = state["model"]
unwanted_prefix = '_orig_mod.'
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

In [14]:
model.load_state_dict(state["model"])

<All keys matched successfully>

In [15]:
stop_tokens = tokenizer.encode(special_tokens["end_of_turn"] + special_tokens["eos"], add_special_tokens=False).ids
print(stop_tokens)

[1, 3]


In [16]:
def chat_completion(user_text):
    return generate(
        model,
        chat_template([
            {"role": "user", "content": user_text}
        ],add_generation_token=True),
        max_new_tokens=200,
        stop_tokens=stop_tokens,
    )

In [17]:
data_dir = "tokenized_data"
test_ds_name = data_dir + "/test_chats"
ds_train = load_from_disk(test_ds_name).take(10).shuffle(0)
token_ids=next(iter(ds_train))["tokens"]
tokenizer.decode(token_ids)

'Write one paragraph about the Tlachihualtepetl pyramid. Keep it at or below 15 sentences. Include its size and when it was built. Also, include any aspects of worship it had. Include aspects of its more recent history, too. Name at least one specific year. \nWhen people hear the word "pyramid," they often think of Egypt. However, the world\'s largest pyramid is in Cholula, Mexico. Tlachihualtepetl, also called the Great Pyramid of Cholula, is 217 feet (66 meters) tall and its base is 1,476 by 1,1476 feet (450 by 450 meters). Its total volume is 157 million cubic feet. Due to its size, it took centuries to build. It was worked on several different times between 300 BC and 700 AD. There was a temple above it to worship Quetzalcoatl - the Aztec god of learning and books. However, the temple was abandoned around the 8th or 9th century AD. Over time, it became covered with vegetation. As a result, it blended in and Spaniards first mistook it for a hill when they conquered Mexico in 1521. T

In [18]:
# test
chat_completion("Write one paragraph about the Tlachihualtepetl pyramid.")

['<|bos|>', '<|user|>', 'W', 'rite', 'Ġone', 'Ġparagraph', 'Ġabout', 'Ġthe', 'ĠT', 'l', 'ach', 'ih', 'ual', 't', 'ep', 'et', 'l', 'Ġpyramid', '.', '<|endofturn|>', 'Ċ', '<|assistant|>']


"Write one paragraph about the Tlachihualtepetl pyramid.\nIf the Tlachihric What You First\n\nIt's real that Tlachihualtepret's natural Tlii. It's one of the most popular tlists in the world, but it's a secret one. It's a type of art that, which comes from a small type of colors that can be used by simple colors. It is very interesting to add some of the most popular Tlihualt, which is good breeds, and it's an interesting example to the tlachihualtepetl. Tlachihualtepetl is a type of art that evokes the basics of the world's art, making a feature that is truly unique and unique. It's a unique and unique style that can be used to add a touch to the artist's style. It's also a type of art that can add a touch to a touch of boldness or a touch, so make sure that it's"

In [19]:
# train
chat_completion("I'm trying to create a menu with different kinds of pasta. Help me come up with different types of pasta and what they are best used for.")

['<|bos|>', '<|user|>', 'I', "'m", 'Ġtrying', 'Ġto', 'Ġcreate', 'Ġa', 'Ġmenu', 'Ġwith', 'Ġdifferent', 'Ġkinds', 'Ġof', 'Ġpasta', '.', 'ĠHelp', 'Ġme', 'Ġcome', 'Ġup', 'Ġwith', 'Ġdifferent', 'Ġtypes', 'Ġof', 'Ġpasta', 'Ġand', 'Ġwhat', 'Ġthey', 'Ġare', 'Ġbest', 'Ġused', 'Ġfor', '.', '<|endofturn|>', 'Ċ', '<|assistant|>']
Reached stop token


"I'm trying to create a menu with different kinds of pasta. Help me come up with different types of pasta and what they are best used for.\nHere are a list of pasta!\n\n1. Cat-fil-fil-fil-fil- sauce\n2. Potato sauce\n3. Butter-fil-fil-fil- sauce\n4. Parmesanagine-fil-fil-fil- sauce\n5. Cauliflower-fil-fil- sauce\n6. Parmesan sauce\n7. Parmesan sauce\n8. Parmesan sauce\n9. Parmesan sauce"

In [20]:
# train
chat_completion("Write me an acrostic poem about cheese.")

['<|bos|>', '<|user|>', 'W', 'rite', 'Ġme', 'Ġan', 'Ġac', 'ro', 'stic', 'Ġpoem', 'Ġabout', 'Ġcheese', '.', '<|endofturn|>', 'Ċ', '<|assistant|>']


"Write me an acrostic poem about cheese.\nFat is this kind of cheese,\nIt is so much.\nI love it more than cheese,\nIt's just so much to make\nI love it.\nI love it more than white.\nI love the cheese.\nIt is so small,\nAnd it's so very small,\nI love it.\nI love it more than almond milk.\nIt's so small,\nI love it just the cheese.\nI love it more than white.\nI love it more than white,\nI love it more than white.\nAnd it's so small,\nI love it more than white.\nI love it more than white.\nI love it more than white.\nSo love it more than white.\nI love it more than white,\nI love it more than white.\nI love it with it more than white.\nI love it more than white.\nI love it more than white.\nI love"

In [21]:
chat_completion("When did Michael Jackson pass away?")

['<|bos|>', '<|user|>', 'When', 'Ġdid', 'ĠMichael', 'ĠJackson', 'Ġpass', 'Ġaway', '?', '<|endofturn|>', 'Ċ', '<|assistant|>']


"When did Michael Jackson pass away?\n Michael Jackson pass back home from his home at the University of Pennsylvania. He is the only one person on the board for the largest episodes ever. He is a professional and it was a slip-back who can be played by many of the sports. He is a extract from the area. He has a passionate passion for his ability to be and the world's most well-known hobby.\n\nWhen Michael Jackson passed away, he saw the state team playing against theantly many other State feelings. He is a curious farmer who is a farmer, and he is a regular farmer. He has a passion for his ability to be in the field, and he is a reference to the team. He is a good fit for his ability to be and the state team's ability to be a hero.\n\nWhen Michael Jackson's work with the state team, he is a0Hello. The state has a rich history and a whole number of people who have a great love for the state."

I completed the available tutorials and built a minified GPT-2 clone.
I have issues with the result of the instruction training.
The resulting model can somewhat answer chat messages, but it's very inconfident.
It often answers something unrelated, or repeats the same words.
It can give a coherent answer only around 1 of 3 messages.
With dropout, the validation loss converges at ~4.25 and the training loss at ~3.2
Without dropout, the validation loss converges at ~5 and the training loss at ~0
What causes this? Too small model? Bad training?

The model:
- context length: 1000
- embedding dimensions: 384
- attention heads: 12
- transformer blocks: 12
- FFN width multilier: 4
- vocabulary size: 24k
- params: ~40M

Techniques:
- weight tying
- custom initialized weights
- custom optimizer config with gradient decay groups
- gradient clipping
- cuda optimizations
- temperature: 0.1, top_k: 10
- the chat inference stops when reaching EOS or end_of_turn tokens
- The chat inference adds the generation token

Pre-training:
- dropout: 0
- training tokens: 3B (too large)
- unique tokens: 1.5B
- learning rate schedule: peak 6e-4, min 6e-5
- batch size: 32
- BOS, EOS, end_of_text tokens
- The generated text here is pretty confident, but the repeating still appears sometimes

Instruction training:
- dropout: 0.1
- training chats: 5 * 9500
- unique chats: 9500
- learning rate: fixed, 1e-5
- batch size: 8 (smaller, because it was trained on a different GPU)
- I checked the training data, loss masking, attention masking, special tokens


Examples chats:

- User: When did Michael Jackson pass away?
- Assistant: Michael Jackson's first walked away from the United States on August 14, 1969. He was the second in the first place to walk away from the American States. He was the second to walk away from the country in the country and was the second to walk away from the country. He was also the second to walk away from the country in the country in the country.

- User: Write me an acrostic poem about cheese.
- Assistant: Oh, I love,\nOh, I love,\nOh, how can I love you?\nOh, I love you,\nOh, I love you,\nOh, how can I love you?\nOh, I love you, you love you,\nOh, how can I love you,\nOh, how I love you,\nI love you, love you,\nOh, how love you,\nOh, how love you, how love you,\nOh, how love you,\nOh, how love you,\nOh, how love you,\nOh, how love you,\nOh, how love you, my love,\nOh, how love you,\nOh, how love you,\nOh, how love you,\nOh, how love you,\nOh, how love you,\nOh, how love, I love you,\nOh, how love you, how love you,\nOh, how love you,\nOh

- User: I'm trying to create a menu with different kinds of pasta. Help me come up with different types of pasta and what they are best used for.
- Assistant: Here are a few pasta that are best used for a pasta:\n1. Cauliflower. This pasta is great for pasta and can be used to make it a pasta-together recipe.\n2. Parmesan. This pasta is great for pasta and pasta.\n3. Parmesan. Parmesan is a pasta-together pasta.