## https://colab.research.google.com/drive/1ak4kOtbIQGXE5kuhhGTd55xu4qRpeZd7?usp=sharing

## N-Gram Model (AR) Over Bytes

We consider an AR language model over bytes (256 tokens).

- We give it algorithmic training data.
- We can then see how well the model can predict the next token given the context
- Since we know the data generating process, we can see how well the model captures the underlying process.
    - *Spoiler*: It doesn't do well.

- This is a Markov chain on the order of $O(256^n)$ states at maximum, but the
algortihmic data is low-dimensional so it's very *sparse*.

- We represent our $n$-gram model as a dictionary of dictionaries.
    - Outer dictionary is indexed by the context.
    - Inner dictionary is indexed by the next token.
    - Each token given the context maps to the number of times that token was
    observed in the training data.
        - Normalize by the total count to get a probability.

- This is simple model and simple data
    - Hopefully, exploring its properties can help us understand LLMs.

### Finite State Machines

We can view AR-LMs as finite state machines (if deterministic) otherwise
Markvo chains without loss of generality.

- Computers are FSMs, just very large ones.
- LLMs are also very large FSMs.

https://www.lesswrong.com/posts/7qSHKYRnqyrumEfbt

- Thus, AR-LLMs are differentiable computers that can learn from examples.

In [2]:
import random
import pprint

class MarkovChain:
    def __init__(self):
        self.model = {}

    def __str__(self):
        pprint(self.model)

    def percept(self, prev_tokens, next_token):
        if prev_tokens not in self.model:
            self.model[prev_tokens] = {}
        if next_token not in self.model[prev_tokens]:
            self.model[prev_tokens][next_token] = 0
        self.model[prev_tokens][next_token] += 1

    def train(self, data, order = 100):
        N = len(data)
        for i in range(N):
            tokens = data[i]
            m = len(tokens)
            if i % 10000 == 0:
                print(f'{i/N*100} percent complete')
            for j in range(m):
                for k in range(0, order+1):
                    if m <= j+k:
                        break
                    self.percept(prev_tokens = tokens[j:j+k],
                                 next_token = tokens[j+k])

    def distribution(self, ctx):
        if ctx not in self.model:
            return self.distribution(ctx[1:])        
        total = sum(self.model[ctx].values())
        return self.model[ctx] | {k: v / total for k, v in self.model[ctx].items()}
    
    def predict(self, ctx):
        d = self.distribution(ctx)
        r = random.random()
        s = 0
        for token, p in d.items():
            s += p
            if s >= r:
                return token
        raise ValueError('no candidate found')

    def generate(self, ctx, max_tokens = 100, stop_token='.'):
        output = ''
        for i in range(max_tokens):
            token = self.predict(ctx)
            output += token
            if token == stop_token:
                break
            ctx = ctx + token

        return output

### Algorithmic Training Data

Our training data are byte sequences of expression trees.
- The language has a few basic operations: `srt` (sort), `sum`, `min`, and `max`.
- Each operation operations on lists of integers and returns lists of integers.
- We train the model on *random* expression trees.
- We can then use the model to **generate** new expression trees
- We can **prompt** it with a partial expression trees to see how well it predicts the rest of the tree.


![expression-tree](./expr_tree.png)

Here is a depiction of `srt[sum[4,3],max[3,5,2]]=[6,7]`.
    - If we prompt the model with `srt[sum[4,3],max[3,5,2]]=`, we expect it to predict `[6,7]`.

In [3]:
import random
import math

def srt(x):
    return sorted(x)

def default_args():
    return {
        'operations': [sum, sum, min, max, srt],
        'recurse_prob': 0.25,
        'min_child': 1,
        'max_child': 5,
        'values': [1, 2, 3, 4, 5, 6, 7, 8, 9],
        'min_depth': 0,
        'max_depth': 4,
        'debug': False
    }

def generate_data(
        samples=1, 
        stop_tok='.',
        args=default_args()):

    for k in default_args().keys():
        if k not in args:
            args[k] = default_args()[k]

    sample = []
    for _ in range(samples):
        data = generate_tree(
            args['operations'],
            args['recurse_prob'],
            args['min_child'],
            args['max_child'],
            args['values'],
            args['min_depth'],
            args['max_depth'],
            args['debug'])

        result = data['result']
        if isinstance(result, list) and len(result) == 1:
            result = result[0]

        tok_seq = f"{data['expr']}={result}{stop_tok}"
        tok_seq.replace(' ', '')
        sample.append(tok_seq)
    return sample

def generate_tree(operations, recurse_prob, min_child, max_child,
        values, min_depth, max_depth, debug, offset=''):
    
    if max_depth != 0 and (min_depth >= 0 or random.random() < recurse_prob):
        num_nodes = random.randint(min_child, max_child)

        childs = [generate_tree(
            operations=operations,
            recurse_prob=recurse_prob,
            min_child=min_child,
            max_child=max_child,
            values=values,
            min_depth=min_depth-1,
            max_depth=max_depth-1,
            debug=debug,
            offset = offset + '  ') for _ in range(num_nodes)]
        
        child_results = []
        for child in childs:
            if not isinstance(child['result'], list):
                child['result'] = [child['result']]
            child_results.extend(child['result'])
        op = random.choice(operations)   
        res = op(child_results)
        if not isinstance(res, list):
            res = [res]
        expr = f'{op.__name__}[{",".join([child["expr"] for child in childs])}]'
        data = {'result': res, 'expr': expr}
        if debug:
            print(f'{offset}{data=}')
        return data
    else:
        v = random.choice(values)
        data = {'result': v, 'expr': f'{v}'}
        
        if debug:
            print(f'{offset}{data=}')
        return data

Next, we generate some data. It's going to be the simplest data. It's 
an expression tree 1 level deep:

In [4]:
sample = generate_data(
   samples=10,
   args={'debug': False, 'min_depth': 0, 'max_depth': 0})
print(sample)

['6=6.', '2=2.', '5=5.', '7=7.', '8=8.', '1=1.', '6=6.', '7=7.', '7=7.', '4=4.']


Let's generate a bit more complicated data, `operation[list]=list`:

In [5]:
sample = generate_data(
   samples=1000,
   args={'debug': False, 'min_depth': 1, 'max_depth': 1})
print(sample[:3])

['min[6,5,2,2]=2.', 'min[9,8,8,2,1]=1.', 'sum[3,5,7,9]=24.']


Let's train it! This is pretty straight forward to learn, right?

In [6]:
model = MarkovChain()
model.train(sample)

0.0 percent complete


In [7]:
print(model.generate('sum'))
print(model.generate('sum[1'))
print(model.generate('sum[1,2,3]='))

[7,3,8,8,8]=34.
]=1.
3.


We have just learned the most inefficient way to evaluate a simple expression
tree. Let's extend the number of operands!

In [83]:
print(model.generate('sum[1,2,3,4'))
print(model.generate('sum[1,2,3,4,5'))
print(model.generate('sum[1,2,3,4,5'))

]=10.
]=14.
]=[2, 3, 4, 5].


In [8]:
model.model['sum[1,2,3,4,5']

KeyError: 'sum[1,2,3,4,5'

It's never seen this data before. So, what does it do?
It throws away the oldest bytes until it has seen the content.
Then, it predicts the next byte based on that.

In [94]:
print(model.model['2,3,4,5'])
model.generate('2,3,4,5')

{']': 10}


']=14.'

## Inductive Bias

Throwing away oldest bytes is a strong inductive bias.

- Not necessarily true that the next byte is dependent on the oldest bytes.

## Generative Model
- Generate text by starting with a any context and then sampling from the
probability distribution for that context to get the next token.
- Repeat until we have generated the desired number of tokens.
- Same way LLMs work (but they work well).

## Analysis
Our model has some advantages compared to transformer-based AR-LLMs:

- Since we simply *store* the data:
    - Easy to implement.
    - Easy to make it a lifelong learner. Store *more data*.

But, compared to more sophisticated models, they have huge disadvantages:

- $n$-gram model is not able to capture long-range dependencies in the data.
    - Number of states grows exponentially with the order of the model.
    - It cannot scale to large contexts, and therefore cannot understand
    nuances in the data.

- $n$-gram model does not generalize out-of-distribution very well.
    - Since language is a high-dimensional space, *most* contexts have never been seen before.

## Conclusion

A *good* model *compresses* the data.

- This is a key concept in ML.
- There is a notion that *compression* is a proxy for *understanding*.
- Take a *physics simulation*: we don't need to store the position and velocity
of every particle.
- We can just store the starting conditions and then let the laws of physics
play out.
    - Not perfect, but perfectly predicting the future is not possible.
    - We only need to predict it well enough to make good decisions.

`Prediction = compression = intelligence`

The brain may be a good example of this.