In [2]:
import jax
import jax.numpy as jnp
import flax
import tokenizer
import time
from model import Tranformer, GPTConfig #, ChessGPT, cross_entropy_loss
import os
import sys
import optax
from tqdm import tqdm
import pickle
from utils import saveWeights, loadWeights
import numpy as np
# from tokenizer import tokenizeLine
# tokenizer.tokenize()



INT_DTYPE = jnp.int16
FLOAT_DTYPE = jnp.float16
vocab, vocabDecode = tokenizer.makeVocabUCI_SMALL()
PAD_TOKEN = vocab['<PAD>']
nBatches = 10000
BATCH_SIZE = 128
BATCH_SIZE = 64
BLOCK_SIZE = 400
# BLOCK_SIZE = 512
CONTEXT_LENGTH = tokenizer.MAX_MOVES*3+1
RAND_SEED = 123
VOCAB_SIZE = len(vocabDecode)
randKEY = jax.random.PRNGKey(seed=RAND_SEED)


print("Loading Vocab")
gamePath = 'data/ELO_2000_UCI.txt'
print("Opening Games File")
file = open(gamePath, 'r')
print("Reading Games File")
# PROBLEMO -------------------------- HUGE --------------------- PROBLEMO 
# games = file.read(200000000)
games = file.read()
print("Closing Games File")
file.close()
print('Spliting Games')
games = games.splitlines()
print("FNIISHED Spliting Games File")
print('Length of GAMES:',len(games))
# sys.exit()
# games = games[100000:130000]
games = games[:13000]
tokenizedGames = []
print("Tokenizing Games")
for g in tqdm(games):
    # g = g[:min((len(g), 500))]
    arr = jnp.array(tokenizer.tokenizeLine(g, vocab, BLOCK_SIZE, pad=True), dtype=jnp.int16)
    tokenizedGames.append(arr)

print("Converting to jnp array")
JtokenizedGames = jnp.vstack(tokenizedGames)
print("FINISHED converting to jnp array")
config = GPTConfig()
config.vocab_size = VOCAB_SIZE
config.n_layer = 12
config.n_head = 12
config.n_embd = 768
config.dropout = 0.0
config.block_size = CONTEXT_LENGTH
config.bias = True

chessModel = Tranformer(config)



Loading Vocab
Opening Games File
Reading Games File
Closing Games File
Spliting Games
FNIISHED Spliting Games File
Length of GAMES: 100000
Tokenizing Games


100%|██████████| 13000/13000 [00:02<00:00, 4448.99it/s]


Converting to jnp array
FINISHED converting to jnp array


In [31]:
@jax.jit
def getBatchSplit(randKey:jax.random.PRNGKey):
    randKey, k = jax.random.split(randKey)
    idx = jax.random.randint(k, (BATCH_SIZE,), 0, len(JtokenizedGames))
    batch = jnp.take(JtokenizedGames, idx, axis=0)
    d,t, idxs, randKey = splitGames(batch,randKey)
    return d,t, idxs, randKey

@jax.jit
def getBatch(randKey:jax.random.PRNGKey):

    randKey, k = jax.random.split(randKey)
    idx = jax.random.randint(k, (BATCH_SIZE,), 0, len(JtokenizedGames))
    batch = jnp.take(JtokenizedGames, idx, axis=0)
    return batch, randKey

@jax.jit
def splitGame(x:jnp.array, randKey:jax.random.PRNGKey):
    ind = jnp.argmax(jnp.equal(x, PAD_TOKEN), axis=0)
    idx = jax.random.randint(randKey, (1,), 2, ind)[0]
    maskY = jnp.where(jnp.arange(x.shape[0]) <= idx, 1, 0)
    maskX = jnp.where(jnp.arange(x.shape[0]) < idx, 1, 0)
    return x*maskX, x*maskY, idx
@jax.jit
def splitGames(batch:jnp.array, randKey:jax.random.PRNGKey):
    randKeys = jax.random.split(randKey, BATCH_SIZE)
    randKey, k = jax.random.split(randKey)
    d,t,idxs = jax.vmap(splitGame)(batch,randKeys)
    return d,t, idxs, randKey
@jax.jit
def getLoss(params, d, t, idxs):
    logits = chessModel.apply(params, d)
    logits = logits[:, idxs-1, :]
    t = t[:, idxs]
    t_one_hot = jax.nn.one_hot(t, config.vocab_size)
    loss = optax.softmax_cross_entropy(logits, t_one_hot)
    loss = jnp.mean(loss)
    return loss
@jax.jit
def getLossOLD(params, d, t):
    maskD = jnp.equal(d, PAD_TOKEN)
    maskD = 1 - maskD
    maskT = jnp.equal(t, PAD_TOKEN)
    maskT = 1 - maskT
    logits = chessModel.apply(params, d)
    logits = logits * maskD[:, :, None]
    t_one_hot = jax.nn.one_hot(t, config.vocab_size)
    t_one_hot = t_one_hot * maskT[:, :, None]
    loss = optax.softmax_cross_entropy(logits, t_one_hot)
    loss = jnp.mean(loss)
    return loss

In [34]:



d = jnp.empty((BATCH_SIZE, BLOCK_SIZE), dtype=INT_DTYPE)
d_size_gb = d.size * d.itemsize / 1024**3
print('JNP Batch GB size',d_size_gb)
# dnp = np.ones((BATCH_SIZE, BLOCK_SIZE), dtype=np.int16)
# input('Cont?')
print('Initializing PARAMS')
params = chessModel.init(jax.random.PRNGKey(0), d)
# print('Casting to PARAMS float16')
# params = jax.tree_map(lambda x: x.astype(jnp.float16), params)
# print('FINISHED Casting PARAMS to float16')
print('Making ADAM Optimizer')
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)
print('FINISHED Making ADAM Optimizer')
lossGrad = jax.jit(jax.grad(getLoss))

for i in tqdm(range(10)):
    # b, randKEY = getBatch(randKEY)
    # # d,t = makeTargets(b)
    # d,t, randKEY = splitGames(b,randKEY)
    d,t,idxs, randKEY = getBatchSplit(randKEY)
    print('Getting Loss')
    loss = getLoss(params, d, t, idxs)
    print('Loss TYPE', type(loss), loss)
    print('Getting Grads')
    grads = lossGrad(params, d, t, idxs)
    print('GRAD TYPE', type(grads))

    print('GETTING Updates')
    updates, opt_state = optimizer.update(grads, opt_state, params)
    print('params', type(params), params['params'].keys())
    print('updates', type(updates), updates['params'].keys())

    print('params', type(params), params['params']['blocks_0'])
    print('updates', type(updates), updates['params']['blocks_0'])
    # print(updates)
    # print(params)
    params = optax.apply_updates(params, updates)
    print('POST Updates',type(params), params['params']['blocks_0'])
    # print(params)
    print(i, " | Loss", jnp.mean(loss), randKEY)
    print(d[0, :100])
    print(t[0, :100])


JNP Batch GB size 4.76837158203125e-05
Initializing PARAMS
Making ADAM Optimizer
FINISHED Making ADAM Optimizer


  0%|          | 0/10 [00:00<?, ?it/s]

Getting Loss
Loss TYPE <class 'jaxlib.xla_extension.ArrayImpl'> 8.464714
Getting Grads
GRAD TYPE <class 'dict'>
GETTING Updates


 10%|█         | 1/10 [02:05<18:49, 125.49s/it]

params <class 'dict'> dict_keys(['wte', 'wpe', 'blocks_0', 'blocks_1', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'blocks_10', 'blocks_11', 'layerNorm', 'lm_head'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
params <class 'dict'> {'norm1': {'layer': {'scale': Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.

 20%|██        | 2/10 [03:50<15:06, 113.29s/it]

params <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
params <class 'dict'> {'attend': {'attend': {'key': {'bias': Array([[-1.34300599e-05, -8.98382439e-07,  9.13992517e-06,
        -5.59946648e-06,  4.11767542e-06,  5.56594478e-06,
         4.83532767e-06, -1.24439111e-05,  2.55114333e-06,
         4.87580382e-06,  3.56207397e-06, -1.33064013e-05,
         8.94661025e-06,  2.66646839e-06, -1.86910438e-05,
         3.97741951e-06,  2.06175719e-06, -2.49857840e-06,
         3.97996246e-07, -1.32808054e-07, -4.11510655e-06,
         2.25661165e-06,  1.09893781e-05,  7.53893983e-06,
        -7.21130164e-07, -2.35

 30%|███       | 3/10 [05:34<12:44, 109.21s/it]

params <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
params <class 'dict'> {'attend': {'attend': {'key': {'bias': Array([[-1.85394092e-05, -1.73536830e-06,  1.43581474e-05,
        -7.87860336e-06,  8.90054434e-06,  8.51442928e-06,
         2.70457349e-06, -1.81906653e-05,  2.78521702e-06,
         8.04653973e-06,  6.54908308e-06, -2.20492584e-05,
         1.52562416e-05,  3.72714908e-06, -2.71557292e-05,
         6.22555672e-06,  2.95123323e-06, -2.31714989e-06,
        -9.90622198e-07,  3.59539058e-06, -6.12035319e-06,
         2.33291530e-06,  1.56919377e-05,  1.08885179e-05,
        -2.64172331e-06, -3.82

 40%|████      | 4/10 [07:18<10:42, 107.07s/it]

params <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
params <class 'dict'> {'attend': {'attend': {'key': {'bias': Array([[-2.39625806e-05, -2.09587984e-06,  1.77385755e-05,
        -9.34298259e-06,  1.30228418e-05,  1.09279463e-05,
        -4.42056717e-07, -2.48777251e-05,  4.00426825e-06,
         1.05419094e-05,  9.50466710e-06, -3.05614813e-05,
         1.80747957e-05,  4.05292576e-06, -2.99848780e-05,
         7.29323710e-06,  4.85119426e-06, -1.90122807e-06,
        -1.29708917e-06,  7.28448322e-06, -8.02213890e-06,
         3.98561224e-06,  1.82965086e-05,  1.32796567e-05,
        -4.25593862e-06, -4.65

 50%|█████     | 5/10 [09:02<08:49, 105.95s/it]

params <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
params <class 'dict'> {'attend': {'attend': {'key': {'bias': Array([[-2.78133139e-05,  1.85083968e-06,  2.15041982e-05,
        -1.02765935e-05,  1.84666333e-05,  1.07422848e-05,
        -9.91208935e-07, -2.92061432e-05, -4.30408500e-06,
         1.31554152e-05,  1.90955943e-05, -2.61625519e-05,
         2.09165701e-05,  4.93085054e-06, -3.86515349e-05,
         7.41726990e-06,  2.82898054e-06, -1.60616321e-06,
        -1.09191149e-06,  1.07442274e-05, -1.00569086e-05,
         5.15835518e-06,  2.16163371e-05,  1.16080082e-05,
        -7.08605421e-06, -4.07

 60%|██████    | 6/10 [10:47<07:02, 105.60s/it]

params <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
params <class 'dict'> {'attend': {'attend': {'key': {'bias': Array([[-3.12552256e-05,  5.10080736e-06,  2.42967162e-05,
        -1.07554361e-05,  2.30031292e-05,  1.07567439e-05,
        -1.59429601e-06, -3.25354558e-05, -1.07802289e-05,
         1.53195815e-05,  2.63546099e-05, -2.34857198e-05,
         2.33307928e-05,  5.28949568e-06, -4.49932559e-05,
         7.87667523e-06,  9.24177698e-07, -1.59928572e-06,
        -9.14632835e-07,  1.34931797e-05, -1.16395040e-05,
         6.57622195e-06,  2.39613873e-05,  1.03560951e-05,
        -9.25902350e-06, -3.70

 70%|███████   | 7/10 [12:31<05:15, 105.01s/it]

params <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
params <class 'dict'> {'attend': {'attend': {'key': {'bias': Array([[-3.42341664e-05,  7.67309211e-06,  2.63541169e-05,
        -1.10686515e-05,  2.64909431e-05,  1.07125188e-05,
        -1.91918502e-06, -3.51799499e-05, -1.58518396e-05,
         1.70537569e-05,  3.16669430e-05, -2.13578569e-05,
         2.52233731e-05,  5.59101909e-06, -5.00257920e-05,
         8.20783498e-06, -5.31367277e-07, -1.54300301e-06,
        -6.68546591e-07,  1.54581849e-05, -1.29378423e-05,
         7.91432558e-06,  2.57850970e-05,  9.28880945e-06,
        -1.10315832e-05, -3.36

 80%|████████  | 8/10 [14:14<03:29, 104.66s/it]

params <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
params <class 'dict'> {'attend': {'attend': {'key': {'bias': Array([[-3.68490073e-05,  9.77382933e-06,  2.79450269e-05,
        -1.13585629e-05,  2.92292916e-05,  1.06242305e-05,
        -2.24155133e-06, -3.71701863e-05, -1.99531441e-05,
         1.84047130e-05,  3.58846446e-05, -1.95220327e-05,
         2.67520882e-05,  5.82715347e-06, -5.41674599e-05,
         8.35483843e-06, -1.62213257e-06, -1.53503879e-06,
        -3.41195062e-07,  1.70557250e-05, -1.39548347e-05,
         8.97227437e-06,  2.73183450e-05,  8.38889355e-06,
        -1.24537928e-05, -3.08

 90%|█████████ | 9/10 [15:58<01:44, 104.45s/it]

params <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
params <class 'dict'> {'attend': {'attend': {'key': {'bias': Array([[-3.92437250e-05,  1.16016472e-05,  2.90233165e-05,
        -1.17479703e-05,  3.13177225e-05,  1.04608898e-05,
        -2.30190653e-06, -3.85624662e-05, -2.34642648e-05,
         1.94105869e-05,  3.91912581e-05, -1.78919418e-05,
         2.78124207e-05,  6.14034252e-06, -5.78016479e-05,
         8.32560272e-06, -2.30904129e-06, -1.37348991e-06,
         1.28065835e-07,  1.81523210e-05, -1.49218758e-05,
         9.97203369e-06,  2.87243602e-05,  7.67157962e-06,
        -1.37658608e-05, -2.94

100%|██████████| 10/10 [17:43<00:00, 106.38s/it]

params <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
params <class 'dict'> {'attend': {'attend': {'key': {'bias': Array([[-4.14728893e-05,  1.33346603e-05,  2.97905826e-05,
        -1.22197944e-05,  3.29054747e-05,  1.03008279e-05,
        -2.15555974e-06, -3.95229181e-05, -2.65509971e-05,
         2.01023540e-05,  4.17153606e-05, -1.62970136e-05,
         2.84902471e-05,  6.53773668e-06, -6.10924180e-05,
         8.09618814e-06, -2.68319127e-06, -1.05205061e-06,
         7.33892421e-07,  1.88717386e-05, -1.59061929e-05,
         1.09438461e-05,  3.01552463e-05,  7.07951494e-06,
        -1.50374581e-05, -2.94




In [27]:
print(grads)

{'params': {'blocks_0': {'attend': {'attend': {'key': {'bias': Array([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, n