In [1]:

import jax
import jax.numpy as jnp
import flax
import tokenizer
from tqdm import tqdm
# from tokenizer import tokenizeLine
# tokenizer.tokenize()

vocab, vocabDecode = tokenizer.makeVocabUCI_SMALL()
PAD_TOKEN = vocab['<PAD>']
nBatches = 10000
BATCH_SIZE = 128

BLOCK_SIZE = 400
# BLOCK_SIZE = 512
CONTEXT_LENGTH = tokenizer.MAX_MOVES*3+1
RAND_SEED = 123
VOCAB_SIZE = len(vocabDecode)
randKEY = jax.random.PRNGKey(seed=RAND_SEED)

BATCH_SIZE = 64
nBatches = 10
games = open('data/ELO_2000_UCI.txt', 'r').read()
games = games.splitlines()
games = games[:5000]
tokenizedGames = []

# for g in games:
#     arr = []
#     for e in g.split(' '):
#         arr.append(vocab[e])
#     tokenizedGames.append(arr)

print("Tokenizing Games")
for g in tqdm(games):
    # g = g[:min((len(g), 500))]
    arr = jnp.array(tokenizer.tokenizeLine(g, vocab, BLOCK_SIZE, pad=True), dtype=jnp.int16)
    tokenizedGames.append(arr)

print("Converting to jnp array")
JtokenizedGames = jnp.vstack(tokenizedGames)
print("FINISHED converting to jnp array")

Tokenizing Games


100%|██████████| 100000/100000 [00:23<00:00, 4322.85it/s]


Converting to jnp array
FINISHED converting to jnp array


In [8]:
from model import Tranformer, GPTConfig #, ChessGPT, cross_entropy_loss
import optax
INT_DTYPE = jnp.int16
FLOAT_DTYPE = jnp.float16
config = GPTConfig()
config.vocab_size = VOCAB_SIZE
config.n_layer = 12
config.n_head = 12
config.n_embd = 768
config.dropout = 0.0
config.block_size = CONTEXT_LENGTH
config.bias = True

chessModel = Tranformer(config)
d = jnp.empty((BATCH_SIZE, BLOCK_SIZE), dtype=INT_DTYPE)
d_size_gb = d.size * d.itemsize / 1024**3
print('JNP Batch GB size',d_size_gb)
# dnp = np.ones((BATCH_SIZE, BLOCK_SIZE), dtype=np.int16)
# input('Cont?')
print('Initializing PARAMS')
params = chessModel.init(jax.random.PRNGKey(0), d)
print('Casting to PARAMS float16')
params = jax.tree_map(lambda x: x.astype(jnp.float16), params)
print('FINISHED Casting PARAMS to float16')
print('Making ADAM Optimizer')
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)
print('FINISHED Making ADAM Optimizer')

JNP Batch GB size 4.76837158203125e-05
Initializing PARAMS
Casting to PARAMS float16
FINISHED Casting PARAMS to float16
Making ADAM Optimizer
FINISHED Making ADAM Optimizer


In [13]:
# @jax.jit
def getBatchSplit(randKey:jax.random.PRNGKey):
    randKey, k = jax.random.split(randKey)
    idx = jax.random.randint(k, (BATCH_SIZE,), 0, len(JtokenizedGames))
    batch = jnp.take(JtokenizedGames, idx, axis=0)
    d,t, idxs, randKey = splitGames(batch,randKey)
    return d,t, idxs, randKey

# @jax.jit
def splitGame(x:jnp.array, randKey:jax.random.PRNGKey):
    # global randKEY
    ind = jnp.argmax(jnp.equal(x, PAD_TOKEN), axis=0)
    # randKey, k = jax.random.split(randKey)
    idx = jax.random.randint(randKey, (1,), 2, ind)[0]

    # idx = np.random.randint(2, ind)
    # print(ind, 'with split at', idx)
    maskY = jnp.where(jnp.arange(x.shape[0]) <= idx, 1, 0)
    # print(maskY)
    maskX = jnp.where(jnp.arange(x.shape[0]) < idx, 1, 0)
    # print(maskX)
    return x*maskX, x*maskY, idx
# @jax.jit
def splitGames(batch:jnp.array, randKey:jax.random.PRNGKey):
    randKeys = jax.random.split(randKey, BATCH_SIZE)
    randKey, k = jax.random.split(randKey)
    d,t,idxs = jax.vmap(splitGame)(batch,randKeys)
    return d,t, idxs, randKey
# @jax.jit
def getLoss(params, d, t, idxs):
    logits = chessModel.apply(params, d)
    logits = logits[:, idxs-1, :]
    t = t[:, idxs]
    t_one_hot = jax.nn.one_hot(t, config.vocab_size)
    loss = optax.softmax_cross_entropy(logits, t_one_hot)
    loss = jnp.mean(loss)
    return loss

In [14]:
dd,tt, idxs,randKey = getBatchSplit(randKEY)
print(dd.shape, tt.shape)
print(dd[0,:150])
print(tt[0,:150])
print(idxs[0])
print('dd0',dd[0,idxs[0]], 'tt0',tt[0,idxs[0]],'dd1',dd[0,idxs[0]+1], 'tt1',tt[0,idxs[0]+1])

(64, 400) (64, 400)
[2361 1839 2101 2362 1846 2032 2363 1519 1807 2364 1201 1488 2365  903
 1112 2366  294  796 2367 1220  555 2368 1512  486 2369   41 1058 2370
  967 1428 2371  591 1174 2372 1863    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0]
[2361 1839 2101 2362 1846 2032 2363 1519 1807 2364 1201 1488 2365  903
 1112 2366  294  796 2367 1220  555 2368 1512  486 2369   41 1058 2370
  967 1428 2371  591 1174 2372 1863 1796    0    0    0    0    0    0
    0

In [16]:
print(getLoss(params, dd, tt, idxs))

8.396398


In [41]:
import jax
import jax.numpy as jnp
import flax.linen as nn

randKEY = jax.random.PRNGKey(seed=123)

# test = 

def splitGame(x:jnp.array, padToken):
    global randKEY
    ind = jnp.where(x == padToken)[0][0]
    
    randKEY, k = jax.random.split(randKEY)
    idx = jax.random.randint(k, (1,), 2, ind)[0]
    print(ind, 'with split at', idx)
    maskY = jnp.where(jnp.arange(x.shape[0]) <= idx, 1, 0)
    # print(maskY)
    maskX = jnp.where(jnp.arange(x.shape[0]) < idx, 1, 0)
    # print(maskX)

    # print(ind)
    # print(x[ind])
    # print(x)
    # print(mask)
    return x*maskX, x*maskY
for i in range(3):
    d, t = splitGame(test, 0)
    print(type(d), type(t))
    print(d)
    print(t)
    print('---------------------')
def splitGames(batch:jnp.array):
    d,t = jax.vmap(splitGame)(batch)
    return d,t

5 with split at 3
<class 'jaxlib.xla_extension.ArrayImpl'> <class 'jaxlib.xla_extension.ArrayImpl'>
[1 3 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 3 3 3 0 0 0 0 0 0 0 0 0 0 0 0 0]
---------------------
5 with split at 2
<class 'jaxlib.xla_extension.ArrayImpl'> <class 'jaxlib.xla_extension.ArrayImpl'>
[1 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 3 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
---------------------
5 with split at 3
<class 'jaxlib.xla_extension.ArrayImpl'> <class 'jaxlib.xla_extension.ArrayImpl'>
[1 3 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[1 3 3 3 0 0 0 0 0 0 0 0 0 0 0 0 0]
---------------------


In [11]:
import time
from model import Tranformer, GPTConfig #, ChessGPT, cross_entropy_loss
import os
import jax
import jax.numpy as jnp
import tokenizer
import sys
import optax
from tqdm import tqdm
import pickle
from utils import saveWeights, loadWeights
import numpy as np



INT_DTYPE = jnp.int16
FLOAT_DTYPE = jnp.float16
vocab, vocabDecode = tokenizer.makeVocabUCI_SMALL()
PAD_TOKEN = vocab['<PAD>']
nBatches = 10000
BATCH_SIZE = 128
BATCH_SIZE = 64
BLOCK_SIZE = 400
# BLOCK_SIZE = 512
CONTEXT_LENGTH = tokenizer.MAX_MOVES*3+1
RAND_SEED = 123
VOCAB_SIZE = len(vocabDecode)
randKEY = jax.random.PRNGKey(seed=RAND_SEED)


print("Loading Vocab")
gamePath = 'data/ELO_2000_UCI.txt'
print("Opening Games File")
file = open(gamePath, 'r')
print("Reading Games File")
# PROBLEMO -------------------------- HUGE --------------------- PROBLEMO 
# games = file.read(200000000)
games = file.read()
print("Closing Games File")
file.close()
print('Spliting Games')
games = games.splitlines()
print("FNIISHED Spliting Games File")
print('Length of GAMES:',len(games))
# sys.exit()
# games = games[100000:130000]
games = games[:13000]
tokenizedGames = []
print("Tokenizing Games")
for g in tqdm(games):
    # g = g[:min((len(g), 500))]
    arr = jnp.array(tokenizer.tokenizeLine(g, vocab, BLOCK_SIZE, pad=True), dtype=jnp.int16)
    tokenizedGames.append(arr)

print("Converting to jnp array")
JtokenizedGames = jnp.vstack(tokenizedGames)
print("FINISHED converting to jnp array")
config = GPTConfig()
config.vocab_size = VOCAB_SIZE
config.n_layer = 12
config.n_head = 12
config.n_embd = 768
config.dropout = 0.0
config.block_size = CONTEXT_LENGTH
config.bias = True

chessModel = Tranformer(config)
d = jnp.empty((BATCH_SIZE, BLOCK_SIZE), dtype=INT_DTYPE)
d_size_gb = d.size * d.itemsize / 1024**3
print('JNP Batch GB size',d_size_gb)
# dnp = np.ones((BATCH_SIZE, BLOCK_SIZE), dtype=np.int16)
# input('Cont?')
print('Initializing PARAMS')
params = chessModel.init(jax.random.PRNGKey(0), d)
print('Casting to PARAMS float16')
params = jax.tree_map(lambda x: x.astype(jnp.float16), params)
def getBatch():
    # k = jax.random.PRNGKey(0)
    global randKEY
    # global JtokenizedGames
    randKEY, k = jax.random.split(randKEY)
    idx = jax.random.randint(k, (BATCH_SIZE,), 0, len(JtokenizedGames))
    # idx = np.random.randint(0, len(JtokenizedGames), (BATCH_SIZE,))
    batch = jnp.take(JtokenizedGames, idx, axis=0)


    return batch

# @jax.jit
def splitGame(x:jnp.array):
    global randKEY
    ind = jnp.argmax(jnp.equal(x, PAD_TOKEN), axis=0)
    randKEY, k = jax.random.split(randKEY)
    idx = jax.random.randint(k, (1,), 2, ind)[0]

    # idx = np.random.randint(2, ind)
    # print(ind, 'with split at', idx)
    maskY = jnp.where(jnp.arange(x.shape[0]) <= idx, 1, 0)
    # print(maskY)
    maskX = jnp.where(jnp.arange(x.shape[0]) < idx, 1, 0)
    # print(maskX)
    return x*maskX, x*maskY
# @jax.jit
def splitGames(batch:jnp.array):
    d,t = jax.vmap(splitGame)(batch)
    return d,t

def getLoss(params, d, t):
    maskD = jnp.equal(d, PAD_TOKEN)
    maskD = 1 - maskD
    maskT = jnp.equal(t, PAD_TOKEN)
    maskT = 1 - maskT
    logits = chessModel.apply(params, d)
    logits = logits * maskD[:, :, None]
    t_one_hot = jax.nn.one_hot(t, config.vocab_size)
    t_one_hot = t_one_hot * maskT[:, :, None]
    loss = optax.softmax_cross_entropy(logits, t_one_hot)
    loss = jnp.mean(loss)
    return loss


Loading Vocab
Opening Games File
Reading Games File
Closing Games File
Spliting Games
FNIISHED Spliting Games File
Length of GAMES: 100000
Tokenizing Games


100%|██████████| 13000/13000 [00:03<00:00, 4196.10it/s]


Converting to jnp array
FINISHED converting to jnp array
JNP Batch GB size 4.76837158203125e-05
Initializing PARAMS
Casting to PARAMS float16


In [12]:
b = getBatch()
print('Batch Shape', b.shape)
print('Batch', b[0])
d,t = splitGames(b)
print('d Shape', d.shape)
print('t Shape', t.shape)
loss = getLoss(params, d, t)
print('Loss Shape', loss.shape)
print('Loss', loss)

Batch Shape (64, 400)
Batch [2361 1839 2101 2362  592 2032 2363  293 1112 2364  696 1731 2365 1220
 1064 2366  929 1162 2367 1196 1739 2368 2136 1428 2369 2206 1995 2370
  576  556 2371 1614 1152 2372 1535  205 2373 1523  860 2374  891 1803
 2375 2117 1472 2376 2163 1958 2377 1528  253 2378 1918 1147 2379 1325
 1921 2380  886 1442 2381 2104 1343 2382   22  750 2383  702 1322 2384
  870 1170 2385 1581 2345 2386  620 1133 2387 1281 2078 2388  963 1488
 2389 1890 2078 2390 1846  782 2391  932 1058 2392 2166 1931 2393 2148
 1325 2394  905 1688 2395 1044 1698 2396  887  795 2397 1203 2087 2398
 1236 1487 2399  887 1795 2400 1236 1487 2401  147 1784 2402  214 1755
 2403  891 1943 2404 1552 1129 2405  620 1442 2406  891 1129 2407  620
    5    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    

In [13]:
b = getBatch()
print('Batch Shape', b.shape)
print('Batch', b[0])
d,t = splitGames(b)
print('d Shape', d.shape)
print('t Shape', t.shape)
loss = getLoss(params, d, t)
print('Loss Shape', loss.shape)
print('Loss', loss)

Batch Shape (64, 400)
Batch [2361 1839 2101 2362  592 1427 2363 1846 1112 2364 1519 1049 2365  893
  860 2366  121  796 2367 1611 1167 2368  702  555 2369 1381 1732 2370
  972  791 2371  434 1802 2372 1219 1022 2373 1201 1488 2374  292 1150
 2375  131 1796 2376  903  860 2377 1513  251 2378  104  484 2379   42
  486 2380  575  204 2381  929 1126 2382  367 1156 2383  663 1453 2384
  300 1062 2385  604  449 2386  127  195 2387 1872  409 2388  706  794
 2389   21  817 2390  344  735 2391 1208 1468 2392  561  474 2393  401
  544 2394  882  214 2395  758 1128 2396  167  412 2397   14  240 2398
  569  219 2399 1265 1064 2400 1219  509 2401   14 1386 2402 1813  861
 2403  372 1401 2404  748 1798 2405 1278  544 2406  219 2298 2407  373
 1092 2408  939  667 2409 1582  448 2410  710  737 2411 2136 1092 2412
 1686  714 2413 1112 2092 2414 1049  662 2415  488 1796 2416  384 1171
 2417  505 1784 2418  698  929 2419 1749 2334 2420 1786 2326 2421 2206
  655 2422 1890  992 2423 1683 1095 2424 1928 146

In [15]:
print(d[0, :150])
print(t[0, :150])

[2361 1839 2101 2362  592 1427 2363 1846 1112 2364 1519 1049 2365  893
  860 2366  121  796 2367 1611 1167 2368  702  555 2369 1381 1732 2370
  972  791 2371  434 1802 2372 1219 1022 2373 1201 1488 2374  292 1150
 2375  131 1796 2376  903    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0]
[2361 1839 2101 2362  592 1427 2363 1846 1112 2364 1519 1049 2365  893
  860 2366  121  796 2367 1611 1167 2368  702  555 2369 1381 1732 2370
  972  791 2371  434 1802 2372 1219 1022 2373 1201 1488 2374  292 1150
 2375  131 1796 2376  903

In [18]:
import jax.numpy as jnp
import jax
import functools
from functools import reduce
# Example function to square leaf elements
def square(x):
    return x ** 2

# Example nested structure
nested_structure = {
    'a': jnp.array([[1, 2, 3],
                   [0, 0, 0],
                   [0, 0, 0]]),
    'b': jnp.array([[6, 7, 8],
                   [0, 0, 0],
                   [0, 0, 0]]),
    'c': {
        'x': jnp.array([[7, 8],
                       [0,0]],
                       [0, 0]),
        'y': jnp.array([[7, 8],
                       [0,0],
                       [0,0]])
    }
}
nested_structure1 = {
    'a': jnp.array([1, 2, 3]),
    'b': jnp.array([6, 7, 8]),
    'c': {
        'x': jnp.array([7, 8]),
        'y': jnp.array([7, 8])
    }
}

nested_structure2 = {
    'a': jnp.array([0, 0, 0]),
    'b': jnp.array([0, 0, 0]),
    'c': {
        'x': jnp.array([0,0]),
        'y': jnp.array([0,0])
    }
}
nested_structure3 = {
    'a': jnp.array([0, 0, 0]),
    'b': jnp.array([0, 0, 0]),
    'c': {
        'x': jnp.array([0,0]),
        'y': jnp.array([0,0])
    }
}
def stack_dicts(d1, d2):
    return jax.tree_map(lambda x, y: jnp.vstack((x, y)), d1, d2)

# Stack the dictionaries along a new axis using jax.vmap
list_of_dicts = [nested_structure1, nested_structure2, nested_structure3]
stacked_dicts = reduce(stack_dicts, list_of_dicts)
print(stacked_dicts)
# stacked_structure = jax.tree_map(jnp.vstack, nested_structure1, nested_structure2)

# Apply the square function using jax.tree_map
meanFn = functools.partial(jnp.mean, axis=0)
transformed_structure = jax.tree_map(meanFn, nested_structure)
tesst = jax.tree_map(meanFn, stacked_dicts)

# Print the original and transformed structures
print("Original structure:")
print(nested_structure)
print("\nTransformed structure:")
print(transformed_structure)
print(tesst)

{'a': Array([[1, 2, 3],
       [0, 0, 0],
       [0, 0, 0]], dtype=int32), 'b': Array([[6, 7, 8],
       [0, 0, 0],
       [0, 0, 0]], dtype=int32), 'c': {'x': Array([[7, 8],
       [0, 0],
       [0, 0]], dtype=int32), 'y': Array([[7, 8],
       [0, 0],
       [0, 0]], dtype=int32)}}
Original structure:
{'a': Array([[1, 2, 3],
       [0, 0, 0]], dtype=int32), 'b': Array([[6, 7, 8],
       [0, 0, 0]], dtype=int32), 'c': {'x': Array([[7, 8],
       [0, 0]], dtype=int32), 'y': Array([[7, 8],
       [0, 0]], dtype=int32)}}

Transformed structure:
{'a': Array([0.5, 1. , 1.5], dtype=float32), 'b': Array([3. , 3.5, 4. ], dtype=float32), 'c': {'x': Array([3.5, 4. ], dtype=float32), 'y': Array([3.5, 4. ], dtype=float32)}}
{'a': Array([0.33333334, 0.6666667 , 1.        ], dtype=float32), 'b': Array([2.       , 2.3333335, 2.6666667], dtype=float32), 'c': {'x': Array([2.3333335, 2.6666667], dtype=float32), 'y': Array([2.3333335, 2.6666667], dtype=float32)}}
