In [2]:
import jax
import jax.numpy as jnp
import flax
import tokenizer
import time
from model import Tranformer, GPTConfig #, ChessGPT, cross_entropy_loss
import os
import sys
import optax
from tqdm import tqdm
import pickle
from utils import saveWeights, loadWeights
import numpy as np
# from tokenizer import tokenizeLine
# tokenizer.tokenize()



INT_DTYPE = jnp.int16
FLOAT_DTYPE = jnp.float16
vocab, vocabDecode = tokenizer.makeVocabUCI_SMALL()
PAD_TOKEN = vocab['<PAD>']
nBatches = 10000
BATCH_SIZE = 128
BATCH_SIZE = 64
BLOCK_SIZE = 400
# BLOCK_SIZE = 512
CONTEXT_LENGTH = tokenizer.MAX_MOVES*3+1
RAND_SEED = 123
VOCAB_SIZE = len(vocabDecode)
randKEY = jax.random.PRNGKey(seed=RAND_SEED)


print("Loading Vocab")
gamePath = 'data/ELO_2000_UCI.txt'
print("Opening Games File")
file = open(gamePath, 'r')
print("Reading Games File")
# PROBLEMO -------------------------- HUGE --------------------- PROBLEMO 
# games = file.read(200000000)
games = file.read()
print("Closing Games File")
file.close()
print('Spliting Games')
games = games.splitlines()
print("FNIISHED Spliting Games File")
print('Length of GAMES:',len(games))
# sys.exit()
# games = games[100000:130000]
games = games[:13000]
tokenizedGames = []
print("Tokenizing Games")
for g in tqdm(games):
    # g = g[:min((len(g), 500))]
    arr = jnp.array(tokenizer.tokenizeLine(g, vocab, BLOCK_SIZE, pad=True), dtype=jnp.int16)
    tokenizedGames.append(arr)

print("Converting to jnp array")
JtokenizedGames = jnp.vstack(tokenizedGames)
print("FINISHED converting to jnp array")
config = GPTConfig()
config.vocab_size = VOCAB_SIZE
config.n_layer = 12
config.n_head = 12
config.n_embd = 768
config.dropout = 0.0
config.block_size = CONTEXT_LENGTH
config.bias = True

chessModel = Tranformer(config)



Loading Vocab
Opening Games File
Reading Games File
Closing Games File
Spliting Games
FNIISHED Spliting Games File
Length of GAMES: 100000
Tokenizing Games


100%|██████████| 13000/13000 [00:02<00:00, 4448.99it/s]


Converting to jnp array
FINISHED converting to jnp array


In [24]:
@jax.jit
def getBatchSplit(randKey:jax.random.PRNGKey):
    randKey, k = jax.random.split(randKey)
    idx = jax.random.randint(k, (BATCH_SIZE,), 0, len(JtokenizedGames))
    batch = jnp.take(JtokenizedGames, idx, axis=0)
    d,t, idxs, randKey = splitGames(batch,randKey)
    return d,t, idxs, randKey

@jax.jit
def getBatch(randKey:jax.random.PRNGKey):

    randKey, k = jax.random.split(randKey)
    idx = jax.random.randint(k, (BATCH_SIZE,), 0, len(JtokenizedGames))
    batch = jnp.take(JtokenizedGames, idx, axis=0)
    return batch, randKey

@jax.jit
def splitGame(x:jnp.array, randKey:jax.random.PRNGKey):
    ind = jnp.argmax(jnp.equal(x, PAD_TOKEN), axis=0)
    idx = jax.random.randint(randKey, (1,), 2, ind)[0]
    maskY = jnp.where(jnp.arange(x.shape[0]) <= idx, 1, 0)
    maskX = jnp.where(jnp.arange(x.shape[0]) < idx, 1, 0)
    return x*maskX, x*maskY, idx
@jax.jit
def splitGames(batch:jnp.array, randKey:jax.random.PRNGKey):
    randKeys = jax.random.split(randKey, BATCH_SIZE)
    randKey, k = jax.random.split(randKey)
    d,t,idxs = jax.vmap(splitGame)(batch,randKeys)
    return d,t, idxs, randKey
@jax.jit
def getLoss(params, d, t, idxs):
    logits = chessModel.apply(params, d)
    logits = logits[:, idxs-1, :]
    t = t[:, idxs]
    t_one_hot = jax.nn.one_hot(t, config.vocab_size)
    loss = optax.softmax_cross_entropy(logits, t_one_hot)
    loss = jnp.mean(loss)
    return loss

In [25]:

from turtle import up


d = jnp.empty((BATCH_SIZE, BLOCK_SIZE), dtype=INT_DTYPE)
d_size_gb = d.size * d.itemsize / 1024**3
print('JNP Batch GB size',d_size_gb)
# dnp = np.ones((BATCH_SIZE, BLOCK_SIZE), dtype=np.int16)
# input('Cont?')
print('Initializing PARAMS')
params = chessModel.init(jax.random.PRNGKey(0), d)
print('Casting to PARAMS float16')
params = jax.tree_map(lambda x: x.astype(jnp.float16), params)
print('FINISHED CASTING')
print('Making ADAM Optimizer')
# optimizer = optax.chain(
#     optax.clip_by_global_norm(1.0),  # Clip gradients to prevent explosion
#     optax.adam(learning_rate=1e-5))

optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)
lossGrad = jax.grad(getLoss)
print('FINISHED Making ADAM Optimizer')
for i in tqdm(range(10)):
    b, randKEY = getBatch(randKEY)
    # d,t = makeTargets(b)
    d,t, randKEY = splitGames(b,randKEY)
    print('Getting Loss')
    loss = getLoss(params, d, t)
    print('Getting Grads')
    grads = lossGrad(params, d, t)
    print('GRAD TYPE', type(grads))

    print('GETTING Updates')
    updates, opt_state = optimizer.update(grads, opt_state, params)
    print('params', type(params), params['params'].keys())
    print('updates', type(updates), updates['params'].keys())
    # print(updates)
    # print(params)
    params = optax.apply_updates(params, updates)
    # print(type(params))
    # print(params)
    print(i, " | Loss", jnp.mean(loss), randKEY)
    print(d[0, :100])
    print(t[0, :100])


JNP Batch GB size 4.76837158203125e-05
Initializing PARAMS
Casting to PARAMS float16
FINISHED CASTING
Making ADAM Optimizer
FINISHED Making ADAM Optimizer


  0%|          | 0/10 [00:00<?, ?it/s]

Getting Loss
(64, 400)
post menan loss:  ()
Getting Grads
GETTING Updates


 10%|█         | 1/10 [01:45<15:45, 105.04s/it]

params <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
0  | Loss 1.4941249 [3577238403 1243377846]
[2361 1839 2101 2362  592 1427 2363 1846 1112 2364 1519 1049 2365  893
  795 2366  121  487 2367  673  858 2368  904  556 2369 1201  754 2370
 1513 1175 2371  293 1802 2372 2135 1488 2373  576  742 2374 1286 1450
 2375  613  204 2376 1013 1152 2377  961  472 2378   21  813 2379  667
  195 2380 1287 1450 2381  560  245 2382  728  508 2383  420  198 2384
  869 1245 2385 1611  696 2386 1384 1060 2387 1955 2054 2388 1813 1797
 2389  458 1698 2390  907  845 2391 1130   51 2392 1759 2086 2393  490
  601 2394]
[2361 1839

 20%|██        | 2/10 [02:57<11:25, 85.68s/it] 

params <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
1  | Loss nan [1778655697 1148430379]
[2361 1839 2101 2362 1846 1112 2363 1519  487 2364 1201 1427 2365  903
 1802 2366 1220 1488 2367 1307 1732 2368 1512  796 2369  591  555 2370
  949  412 2371 2136  251 2372  294  403 2373  322 1176 2374  938  138
 2375 1013 1152 2376  374  787 2377  575  133 2378 1196  860 2379 1863
  752 2380  932 1135 2381  881 1452 2382  273 1798 2383 1522 1336 2384
   20   99 2385  972  169 2386  439 1049 2387 1608  750 2388  939 2299
 2389 1206 1739    0    0    0    0    0    0    0    0    0    0    0
    0    0]
[2361 1839 2101 

Exception ignored in: <function _xla_gc_callback at 0x0000010DEFDF9630>
Traceback (most recent call last):
  File "C:\Users\nikhi\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\jax\_src\lib\__init__.py", line 101, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 
 20%|██        | 2/10 [04:07<16:28, 123.51s/it]


KeyboardInterrupt: 

In [27]:
print(grads)

{'params': {'blocks_0': {'attend': {'attend': {'key': {'bias': Array([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, n