In [2]:
import jax
import jax.numpy as jnp
import flax
import tokenizer
import time
from model import Tranformer, GPTConfig #, ChessGPT, cross_entropy_loss
import os
import sys
import optax
from tqdm import tqdm
import pickle
from utils import saveWeights, loadWeights
import numpy as np
# from tokenizer import tokenizeLine
# tokenizer.tokenize()



INT_DTYPE = jnp.int16
FLOAT_DTYPE = jnp.float16
vocab, vocabDecode = tokenizer.makeVocabUCI_SMALL()
PAD_TOKEN = vocab['<PAD>']
nBatches = 10000
BATCH_SIZE = 128
BATCH_SIZE = 64
BLOCK_SIZE = 400
# BLOCK_SIZE = 512
CONTEXT_LENGTH = tokenizer.MAX_MOVES*3+1
RAND_SEED = 123
VOCAB_SIZE = len(vocabDecode)
randKEY = jax.random.PRNGKey(seed=RAND_SEED)


print("Loading Vocab")
gamePath = 'data/ELO_2000_UCI.txt'
print("Opening Games File")
file = open(gamePath, 'r')
print("Reading Games File")
# PROBLEMO -------------------------- HUGE --------------------- PROBLEMO 
# games = file.read(200000000)
games = file.read()
print("Closing Games File")
file.close()
print('Spliting Games')
games = games.splitlines()
print("FNIISHED Spliting Games File")
print('Length of GAMES:',len(games))
# sys.exit()
# games = games[100000:130000]
games = games[:13000]
tokenizedGames = []
print("Tokenizing Games")
for g in tqdm(games):
    # g = g[:min((len(g), 500))]
    arr = jnp.array(tokenizer.tokenizeLine(g, vocab, BLOCK_SIZE, pad=True), dtype=jnp.int16)
    tokenizedGames.append(arr)

print("Converting to jnp array")
JtokenizedGames = jnp.vstack(tokenizedGames)
print("FINISHED converting to jnp array")
config = GPTConfig()
config.vocab_size = VOCAB_SIZE
config.n_layer = 12
config.n_head = 12
config.n_embd = 768
config.dropout = 0.0
config.block_size = CONTEXT_LENGTH
config.bias = True

chessModel = Tranformer(config)



Loading Vocab
Opening Games File
Reading Games File
Closing Games File
Spliting Games
FNIISHED Spliting Games File
Length of GAMES: 100000
Tokenizing Games


100%|██████████| 13000/13000 [00:02<00:00, 4448.99it/s]


Converting to jnp array
FINISHED converting to jnp array


In [31]:
@jax.jit
def getBatchSplit(randKey:jax.random.PRNGKey):
    randKey, k = jax.random.split(randKey)
    idx = jax.random.randint(k, (BATCH_SIZE,), 0, len(JtokenizedGames))
    batch = jnp.take(JtokenizedGames, idx, axis=0)
    d,t, idxs, randKey = splitGames(batch,randKey)
    return d,t, idxs, randKey

@jax.jit
def getBatch(randKey:jax.random.PRNGKey):

    randKey, k = jax.random.split(randKey)
    idx = jax.random.randint(k, (BATCH_SIZE,), 0, len(JtokenizedGames))
    batch = jnp.take(JtokenizedGames, idx, axis=0)
    return batch, randKey

@jax.jit
def splitGame(x:jnp.array, randKey:jax.random.PRNGKey):
    ind = jnp.argmax(jnp.equal(x, PAD_TOKEN), axis=0)
    idx = jax.random.randint(randKey, (1,), 2, ind)[0]
    maskY = jnp.where(jnp.arange(x.shape[0]) <= idx, 1, 0)
    maskX = jnp.where(jnp.arange(x.shape[0]) < idx, 1, 0)
    return x*maskX, x*maskY, idx
@jax.jit
def splitGames(batch:jnp.array, randKey:jax.random.PRNGKey):
    randKeys = jax.random.split(randKey, BATCH_SIZE)
    randKey, k = jax.random.split(randKey)
    d,t,idxs = jax.vmap(splitGame)(batch,randKeys)
    return d,t, idxs, randKey
@jax.jit
def getLoss(params, d, t, idxs):
    logits = chessModel.apply(params, d)
    logits = logits[:, idxs-1, :]
    t = t[:, idxs]
    t_one_hot = jax.nn.one_hot(t, config.vocab_size)
    loss = optax.softmax_cross_entropy(logits, t_one_hot)
    loss = jnp.mean(loss)
    return loss
@jax.jit
def getLossOLD(params, d, t):
    maskD = jnp.equal(d, PAD_TOKEN)
    maskD = 1 - maskD
    maskT = jnp.equal(t, PAD_TOKEN)
    maskT = 1 - maskT
    logits = chessModel.apply(params, d)
    logits = logits * maskD[:, :, None]
    t_one_hot = jax.nn.one_hot(t, config.vocab_size)
    t_one_hot = t_one_hot * maskT[:, :, None]
    loss = optax.softmax_cross_entropy(logits, t_one_hot)
    loss = jnp.mean(loss)
    return loss

In [34]:



d = jnp.empty((BATCH_SIZE, BLOCK_SIZE), dtype=INT_DTYPE)
d_size_gb = d.size * d.itemsize / 1024**3
print('JNP Batch GB size',d_size_gb)
# dnp = np.ones((BATCH_SIZE, BLOCK_SIZE), dtype=np.int16)
# input('Cont?')
print('Initializing PARAMS')
params = chessModel.init(jax.random.PRNGKey(0), d)
# print('Casting to PARAMS float16')
# params = jax.tree_map(lambda x: x.astype(jnp.float16), params)
# print('FINISHED Casting PARAMS to float16')
print('Making ADAM Optimizer')
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)
print('FINISHED Making ADAM Optimizer')
lossGrad = jax.jit(jax.grad(getLoss))

for i in tqdm(range(10)):
    # b, randKEY = getBatch(randKEY)
    # # d,t = makeTargets(b)
    # d,t, randKEY = splitGames(b,randKEY)
    d,t,idxs, randKEY = getBatchSplit(randKEY)
    print('Getting Loss')
    loss = getLoss(params, d, t, idxs)
    print('Loss TYPE', type(loss), loss)
    print('Getting Grads')
    grads = lossGrad(params, d, t, idxs)
    print('GRAD TYPE', type(grads))

    print('GETTING Updates')
    updates, opt_state = optimizer.update(grads, opt_state, params)
    print('params', type(params), params['params'].keys())
    print('updates', type(updates), updates['params'].keys())

    print('params', type(params), params['params']['blocks_0'])
    print('updates', type(updates), updates['params']['blocks_0'])
    # print(updates)
    # print(params)
    params = optax.apply_updates(params, updates)
    print('POST Updates',type(params), params['params']['blocks_0'])
    # print(params)
    print(i, " | Loss", jnp.mean(loss), randKEY)
    print(d[0, :100])
    print(t[0, :100])


JNP Batch GB size 4.76837158203125e-05
Initializing PARAMS
Making ADAM Optimizer
FINISHED Making ADAM Optimizer


  0%|          | 0/10 [00:00<?, ?it/s]

Getting Loss
Loss TYPE <class 'jaxlib.xla_extension.ArrayImpl'> 8.464714
Getting Grads
GRAD TYPE <class 'dict'>
GETTING Updates


 10%|█         | 1/10 [02:05<18:49, 125.49s/it]

params <class 'dict'> dict_keys(['wte', 'wpe', 'blocks_0', 'blocks_1', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'blocks_10', 'blocks_11', 'layerNorm', 'lm_head'])
updates <class 'dict'> dict_keys(['blocks_0', 'blocks_1', 'blocks_10', 'blocks_11', 'blocks_2', 'blocks_3', 'blocks_4', 'blocks_5', 'blocks_6', 'blocks_7', 'blocks_8', 'blocks_9', 'layerNorm', 'lm_head', 'wpe', 'wte'])
params <class 'dict'> {'norm1': {'layer': {'scale': Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.

In [27]:
print(grads)

{'params': {'blocks_0': {'attend': {'attend': {'key': {'bias': Array([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, n