In [1]:
import mlx.core as mx
import mlx.data as dx
import mlx.nn as nn
import tiktoken

# Dataloader with a tokenizer and a sliding window

In [2]:
def gpt_dataset_v1(txt, tokenizer, max_length, stride, batch_size, shuffle=True):
    """Build an mlx-data stream of sliding-window (input, next-token target) pairs.

    The text is tokenized once. A window of ``max_length`` tokens then slides
    over the token ids in steps of ``stride``. Each window's target is the same
    window shifted one token to the right, which is the next-token prediction
    setup.

    Args:
        txt: Raw text to tokenize.
        tokenizer: Object with an ``encode(text, allowed_special=...)`` method,
            such as a ``tiktoken`` encoding.
        max_length: Number of tokens in each input (and target) window.
        stride: Offset between the start positions of consecutive windows.
            ``stride < max_length`` makes windows overlap.
        batch_size: Number of windows per batch.
        shuffle: If True, shuffle the windows before batching.

    Returns:
        An mlx-data stream whose elements are dicts with keys ``"input_ids"``
        and ``"target_ids"``, batched along a new leading axis.

    Raises:
        ValueError: If ``max_length``, ``stride`` or ``batch_size`` is not positive.
        AssertionError: If ``txt`` tokenizes to ``max_length`` tokens or fewer.
    """
    if max_length < 1 or stride < 1 or batch_size < 1:
        # A non-positive stride would silently produce no windows, and a zero
        # max_length would produce empty windows, so reject both up front.
        raise ValueError(
            "max_length, stride and batch_size must all be positive, got "
            f"max_length={max_length}, stride={stride}, batch_size={batch_size}"
        )

    token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
    # Every window needs max_length input tokens plus one extra token for the shifted target.
    assert len(token_ids) > max_length, "Number of tokenized inputs must at least be equal to max_length+1"

    # Sliding window: the last start index i still leaves room for target token i + max_length.
    chunks = [
        {
            "input_ids": mx.array(token_ids[i:i + max_length]),
            "target_ids": mx.array(token_ids[i + 1:i + max_length + 1]),
        }
        for i in range(0, len(token_ids) - max_length, stride)
    ]

    # mlx-data pipeline: in-memory buffer -> stream -> optional shuffle -> batch.
    stream = dx.buffer_from_vector(chunks).to_stream()
    if shuffle:
        # The shuffle buffer is sized to hold every window at once.
        stream = stream.shuffle(buffer_size=len(chunks))
    # NOTE(review): the batched arrays print like NumPy arrays in the cell outputs,
    # so downstream cells wrap them with mx.array(...) before feeding MLX layers.
    return stream.batch(batch_size)

In [3]:
# GPT-2 byte-pair encoding from tiktoken.
tokenizer = tiktoken.get_encoding("gpt2")
# Tiny smoke test: 4-token windows starting every 2 tokens, one window per batch
# (shuffle is left at its default).
dataloader = gpt_dataset_v1(
    "this is a test sentence with mlx",
    tokenizer,
    max_length=4,
    stride=2,
    batch_size=1,
)

In [4]:
# Pull every batch out of the stream so all windows can be inspected together.
[batch for batch in dataloader]

[{'target_ids': array([[ 318,  257, 1332, 6827]], dtype=int32),
  'input_ids': array([[5661,  318,  257, 1332]], dtype=int32)},
 {'target_ids': array([[ 1332,  6827,   351, 25962]], dtype=int32),
  'input_ids': array([[ 257, 1332, 6827,  351]], dtype=int32)}]

In [5]:
# The first input window plus its last target token, decoded back to text,
# shows that targets are the inputs shifted by one token.
sample_ids = [5661, 318, 257, 1332, 6827]
tokenizer.decode(sample_ids)

'this is a test sentence'

In [6]:
# Load the training text. The encoding is given explicitly so decoding does not
# depend on the platform's default locale.
with open("./the-verdict.txt", "r", encoding="utf-8") as f:
    txt = f.read()
# 256-token windows starting every 128 tokens (50% overlap), shuffled, 8 per batch.
dataset = gpt_dataset_v1(txt, tokenizer, 256, 128,
                         batch_size=8, shuffle=True)

In [7]:
# Keep the iterator in a variable: the next cell continues from where this one stops.
data_iter = iter(dataset)
first_batch = next(data_iter)  # shuffled, so contents vary between runs
print(first_batch)

{'target_ids': array([[ 1807,   673,   750, ...,  2900,   656,   257],
       [  550,  1775,   683, ...,   271, 10899,    11],
       [ 1544, 13818,  4622, ...,   286,   616, 12036],
       ...,
       [  616,  4286,   705, ...,   910,   416,  4150],
       [  508,   550, 18459, ...,   198,   198,  3347],
       [ 1165,   881, 40642, ...,   366,  2215,   673]],
      shape=(8, 256), dtype=int32), 'input_ids': array([[  273,  1807,   673, ..., 15185,  2900,   656],
       [  314,   550,  1775, ...,   402,   271, 10899],
       [  198,  1544, 13818, ..., 13476,   286,   616],
       ...,
       [  286,   616,  4286, ...,   470,   910,   416],
       [   11,   508,   550, ...,   526,   198,   198],
       [ 1310,  1165,   881, ...,    13,   366,  2215]],
      shape=(8, 256), dtype=int32)}


In [8]:
# Advancing the same iterator yields the following batch, not a re-draw of the first.
second_batch = next(data_iter)
print(second_batch)

{'target_ids': array([[  198,   198,     1, ...,   262,  5385, 41186],
       [ 1234,  8737,   656, ...,   336,  8375,   503],
       [10899,   550,   366, ...,  2745,    11,   314],
       ...,
       [  683,     0,  3226, ...,   616,   835,   286],
       [17728,   257,  8500, ...,   465,  2330, 22645],
       [ 1908,   329,   345, ...,  2474,   198,   198]],
      shape=(8, 256), dtype=int32), 'input_ids': array([[   13,   198,   198, ...,   286,   262,  5385],
       [  284,  1234,  8737, ...,   290,   336,  8375],
       [  271, 10899,   550, ..., 29543,  2745,    11],
       ...,
       [12036,   683,     0, ...,   284,   616,   835],
       [   11, 17728,   257, ...,   422,   465,  2330],
       [  673,  1908,   329, ...,   514,  2474,   198]],
      shape=(8, 256), dtype=int32)}


# Create token embeddings

In [9]:
# Toy example: a 6-token vocabulary embedded into 3 dimensions.
vocab_size, output_dim = 6, 3
# Seed before constructing the layer so its random initial weights are reproducible.
mx.random.seed(123)
embedding_layer = nn.Embedding(vocab_size, output_dim)
# Looking up token 3 returns row 3 of the weight matrix.
embedding_layer.weight, embedding_layer(mx.array([3]))

(array([[0.774268, 0.240581, -0.233984],
        [0.496537, 0.00315234, -0.397442],
        [-1.25292, -0.244347, 0.326495],
        [-0.292979, 0.258776, -0.41039],
        [-0.225358, -0.997256, -0.623246],
        [-0.424824, 0.875692, 0.277775]], dtype=float32),
 array([[-0.292979, 0.258776, -0.41039]], dtype=float32))

In [10]:
# Looking up several ids at once gathers the matching weight rows, in order.
input_ids = mx.array([2, 3, 5, 1])
embedding_layer(input_ids)

array([[-1.25292, -0.244347, 0.326495],
       [-0.292979, 0.258776, -0.41039],
       [-0.424824, 0.875692, 0.277775],
       [0.496537, 0.00315234, -0.397442]], dtype=float32)

# Encoding token positions

In [11]:
# Realistic sizes. Derive the vocabulary size from the tokenizer (50257 for "gpt2")
# so the embedding table always has one row per possible token id.
vocab_size = tokenizer.n_vocab
output_dim = 256
token_embedding_layer = nn.Embedding(vocab_size, output_dim)

In [12]:
# Non-overlapping windows (stride == max_length), unshuffled, so the first batch
# starts at the beginning of the text.
max_length = 4
dataset = gpt_dataset_v1(
    txt,
    tokenizer,
    max_length=max_length,
    stride=max_length,
    batch_size=8,
    shuffle=False,
)
data_iter = iter(dataset)
first_batch = next(data_iter)
inputs, targets = first_batch["input_ids"], first_batch["target_ids"]

In [13]:
# Each target row is the matching input row shifted left by one token.
(inputs, inputs.shape, targets, targets.shape)

(array([[   40,   367,  2885,  1464],
        [ 1807,  3619,   402,   271],
        [10899,  2138,   257,  7026],
        [15632,   438,  2016,   257],
        [  922,  5891,  1576,   438],
        [  568,   340,   373,   645],
        [ 1049,  5975,   284,   502],
        [  284,  3285,   326,    11]], dtype=int32),
 (8, 4),
 array([[  367,  2885,  1464,  1807],
        [ 3619,   402,   271, 10899],
        [ 2138,   257,  7026, 15632],
        [  438,  2016,   257,   922],
        [ 5891,  1576,   438,   568],
        [  340,   373,   645,  1049],
        [ 5975,   284,   502,   284],
        [ 3285,   326,    11,   287]], dtype=int32),
 (8, 4))

In [14]:
# The loader's batches print like NumPy arrays in the outputs above, so wrap them
# in an MLX array before the embedding lookup.
token_embeddings = token_embedding_layer(mx.array(inputs))
# First sequence's embeddings, plus the full (batch, seq, dim) shape.
token_embeddings[0], token_embeddings.shape

(array([[0.00192652, 0.0283879, 0.0952962, ..., 0.0396189, 0.00657118, -0.0729339],
        [0.0438855, -0.00700499, 0.0120941, ..., -0.093685, 0.00781714, 0.0639005],
        [-0.0585878, 0.0482262, -0.0910247, ..., 0.0116109, 0.000168937, -0.128424],
        [0.0250007, -0.0731869, -0.0576285, ..., 0.0203022, 0.0630568, 0.00691061]], dtype=float32),
 (8, 4, 256))

In [15]:
# One learned position vector for each slot in the context window.
context_length = max_length
pos_embedding_layer = nn.Embedding(context_length, output_dim)

In [16]:
# Embed positions 0..seq_len-1. The length is taken from the token embeddings
# these vectors will be added to, not from max_length, which only coincidentally
# matches it. It must not exceed context_length, the number of rows in the layer.
seq_len = token_embeddings.shape[1]
pos_embeddings = pos_embedding_layer(mx.arange(seq_len))
pos_embeddings[0], pos_embeddings.shape

(array([0.0957718, -0.0953044, 0.0774341, ..., -0.118451, -0.0242793, 0.00254553], dtype=float32),
 (4, 256))

In [17]:
# Broadcasting adds the same (seq_len, output_dim) position table to every sequence in the batch.
input_embeddings = token_embeddings + pos_embeddings
input_embeddings[0], input_embeddings.shape

(array([[0.0976984, -0.0669165, 0.17273, ..., -0.078832, -0.0177081, -0.0703884],
        [0.132454, -0.00132004, 0.0429362, ..., -0.0658143, -0.118927, 0.0793585],
        [0.0234318, 0.198176, -0.0387605, ..., 0.0469349, -0.0657431, -0.277958],
        [0.0258491, -0.0750647, -0.0575229, ..., 0.0165579, 0.152153, 0.0663575]], dtype=float32),
 (8, 4, 256))