## N-Gram Model (AR) Over Bytes

We consider an AR language model over bytes (256 tokens). When we give it
algorithmic training data, we are free to speicfy the max contet length, or
the order of the model, we want to consider. We can then see how well the model
can predict the next token given the context, and how well it can generate
new data.

This is a Markov chain on the order of $O(256^n)$ states at maximum, but the
data is low-dimensional so we can train the model with a relatively high order
$n$.

We represent our n-gram model as a dictionary of dictionaries, where the outer
dictionary is indexed by the context and the inner dictionary is indexed by the
next token and contains the number of times that token was observed in that
context in the training data. (Normalize by the total count to get a probability).

This is a dead simple model, but here we use it to explore the properties of
language models and as a way to understand LLMs.

In [56]:
import random
import pprint

class MarkovChain:
    def __init__(self):
        self.model = {}

    def __str__(self):
        pprint(self.model)

    def percept(self, prev_tokens, next_token):
        if prev_tokens not in self.model:
            self.model[prev_tokens] = {}
        if next_token not in self.model[prev_tokens]:
            self.model[prev_tokens][next_token] = 0
        self.model[prev_tokens][next_token] += 1

    def train(self, data, order = 100):
        N = len(data)
        for i in range(N):
            tokens = data[i]
            m = len(tokens)
            if i % 10000 == 0:
                print(f'{i/N*100} percent complete')
            for j in range(m):
                for k in range(0, order+1):
                    if m <= j+k:
                        break
                    self.percept(prev_tokens = tokens[j:j+k],
                                 next_token = tokens[j+k])

    def distribution(self, ctx):
        if ctx not in self.model:
            return self.distribution(ctx[1:])        
        total = sum(self.model[ctx].values())
        return self.model[ctx] | {k: v / total for k, v in self.model[ctx].items()}
    
    def predict(self, ctx):
        d = self.distribution(ctx)
        r = random.random()
        s = 0
        for token, p in d.items():
            s += p
            if s >= r:
                return token
        raise ValueError('no candidate found')

    def generate(self, ctx, max_tokens = 100, stop_token='.'):
        output = ''
        for i in range(max_tokens):
            token = self.predict(ctx)
            output += token
            if token == stop_token:
                break
            ctx = ctx + token

        return output

### Algorithmic Training Data

Our training data are byte sequences of expression trees in a simple language
that we define. The language has a few basic operations, such as `srt` (sort),
`sum`, `min`, and `max`. Each of these operatiosn takes a list of integers as
input and returns a list of integers as output. We generate random expression
trees in this language, and then use them to train the model. We can then use
the model to generate new expression trees, and if we prompt it with a partial
expression tree, we can see how well it can predict the rest of the tree.

Here is a depiction of `srt[sum[4,3],max[3,5,2]]=[6,7]`. If we prompt the model
with `srt[sum[4,3],max[3,5,2]]=`, we expect it to predict `[6,7]`.

![expression-tree](./expr_tree.png)

Feel free to play around with `generate_data`. It's in the file
`algorithmic_data.py`. We reproduce it in the code below.
You can change the operations, the number of operations, and so on.

In [3]:
import random
import math

def srt(x):
    return sorted(x)

def default_args():
    return {
        'operations': [sum, sum, min, max, srt],
        'recurse_prob': 0.25,
        'min_child': 1,
        'max_child': 5,
        'values': [1, 2, 3, 4, 5, 6, 7, 8, 9],
        'min_depth': 0,
        'max_depth': 4,
        'debug': False
    }

def generate_data(
        samples=1, 
        stop_tok='.',
        args=default_args()):

    for k in default_args().keys():
        if k not in args:
            args[k] = default_args()[k]

    sample = []
    for _ in range(samples):
        data = generate_tree(
            args['operations'],
            args['recurse_prob'],
            args['min_child'],
            args['max_child'],
            args['values'],
            args['min_depth'],
            args['max_depth'],
            args['debug'])

        result = data['result']
        if isinstance(result, list) and len(result) == 1:
            result = result[0]

        tok_seq = f"{data['expr']}={result}{stop_tok}"
        tok_seq.replace(' ', '')
        sample.append(tok_seq)
    return sample

def generate_tree(operations, recurse_prob, min_child, max_child,
        values, min_depth, max_depth, debug, offset=''):
    
    if max_depth != 0 and (min_depth >= 0 or random.random() < recurse_prob):
        num_nodes = random.randint(min_child, max_child)

        childs = [generate_tree(
            operations=operations,
            recurse_prob=recurse_prob,
            min_child=min_child,
            max_child=max_child,
            values=values,
            min_depth=min_depth-1,
            max_depth=max_depth-1,
            debug=debug,
            offset = offset + '  ') for _ in range(num_nodes)]
        
        child_results = []
        for child in childs:
            if not isinstance(child['result'], list):
                child['result'] = [child['result']]
            child_results.extend(child['result'])
        op = random.choice(operations)   
        res = op(child_results)
        if not isinstance(res, list):
            res = [res]
        expr = f'{op.__name__}[{",".join([child["expr"] for child in childs])}]'
        data = {'result': res, 'expr': expr}
        if debug:
            print(f'{offset}{data=}')
        return data
    else:
        v = random.choice(values)
        data = {'result': v, 'expr': f'{v}'}
        
        if debug:
            print(f'{offset}{data=}')
        return data

Next, we generate some data. It's going to be the simplest data. It's 
an expression tree 1 level deep:

In [4]:
sample = generate_data(
   samples=10,
   args={'debug': False, 'min_depth': 0, 'max_depth': 0})
print(sample)

['9=9.', '1=1.', '1=1.', '4=4.', '4=4.', '8=8.', '4=4.', '5=5.', '3=3.', '8=8.']


Let's generate a bit more complicated data, `operation[list]=list`:

In [68]:
sample = generate_data(
   samples=100000,
   args={'debug': False, 'min_depth': 1, 'max_depth': 1})
print(sample[:3])

['min[2,3,6]=2.', 'srt[4,9,4,6]=[4, 4, 6, 9].', 'min[3,7,1]=1.']


Let's train it! This is pretty straight forward to learn, right?

In [69]:
model = MarkovChain()
model.train(sample)

0.0 percent complete
10.0 percent complete
20.0 percent complete
30.0 percent complete
40.0 percent complete
50.0 percent complete
60.0 percent complete
70.0 percent complete
80.0 percent complete
90.0 percent complete


In [76]:
print(model.generate('sum'))
print(model.generate('sum[1'))
print(model.generate('sum[1,2,3]='))

[8,6]=14.
,8,6]=15.
6.


We have just learned the most inefficient way to evaluate a simple expression
tree. Let's extend the number of operands!

In [83]:
print(model.generate('sum[1,2,3,4'))
print(model.generate('sum[1,2,3,4,5'))
print(model.generate('sum[1,2,3,4,5'))

]=10.
]=14.
]=[2, 3, 4, 5].


In [84]:
model.model['sum[1,2,3,4,5']

KeyError: 'sum[1,2,3,4,5'

It's never seen this data before. So, what does it do?
It throws away the oldest bytes until it has seen the content.
Then, it predicts the next byte based on that.

In [94]:
print(model.model['2,3,4,5'])
model.generate('2,3,4,5')

{']': 10}


']=14.'

This throwing away of the oldest bytes is a strong inductive bias. It's not
necessarily true that the next byte is dependent on the oldest bytes. It's
just a simple way to handle unseen data. It's a simple way to handle the
fact that most states are out-of-distribution, particularly in high-dimensional
spaces -- unlike this contrived example for algorithmic data.

We can then generate text by starting with a any context and then sampling from
the probability distribution for that context to get the next token.

Compared to more sophisticated models, like transformer-based models, it
performs poorly. Here's why:

1. The $n$-gram model is not able to capture long-range dependencies in the data
as well, given that the number of states grows exponentially with the order
of the model.

2. The $n$-gram model does not generalize out-of-distribution very well.
Since language is a high-dimensional space, *most* contexts have never been
seen before.

The $n$-gram model does not in practice capture the semantics of a natural
language very well. It is sample inefficient and does not scale to large
contexts.

In our model, we simply *store* the data. This has advantages and disadvantages:

Advantages:
- It's simple and easy to implement.
- It's easy to make it a lifelong learner, because we can simply add new data
  to the model. (This is currently a problem for LLMs, which are not lifelong.)
Disadvantages:
- It's not sample efficient. It requires a lot of data to learn the model.
- It's not scalable. It requires a lot of memory to store the model.
- It doesn't generalize well OOD. A lot of tricks have been tried to improve
  it, but compared to LLMs, they suck.

A *good* model *compresses* the data. This is a key concept in machine learning.
There is a notion that *compression* is a proxy for *understanding*.
Take a physics simulation, for example. We don't need to store the position
and velocity of every particle in the universe. We can just store the
starting conditions and then let the laws of physics play out. It won't be
perfect, but perfectly predicting the future is not possible -- we only need
to predict it well enough to make good decisions.

`Prediction = compression = intelligence`

The brain may be a good example of this.