# [1] Setting dataset

In [1]:
from Dataset import loadingFiles, cut_label, truncated_data, code_to_id, RETAIN_datasets
MIMICPath = 'Insert_into_your_data_path'

# Loading
treat_add_id = loadingFiles(MIMICPath, 'Final_treat_dx.pkl')
comparator_add_id = loadingFiles(MIMICPath, 'Final_comparator_dx.pkl')
treat_add_label = [seq for pid, seq in treat_add_id]

# Cutting label
treat = cut_label(treat_add_label)
comparator = [seq for pid, seq in comparator_add_id]

# Cut length <=2
print('\n## Cutting length one more..')
print('before: ', 'treat: ', len(treat), 'comparator: ', len(comparator))
treat = truncated_data(treat, cut_num=2)
comparator = truncated_data(comparator, cut_num=2)
print('after: ', 'treat: ', len(treat), 'comparator: ', len(comparator))
CKD_code_dict, max_visit_size = code_to_id(treat+comparator)


# MIMIC directory
import os
DATA_PATH = 'Insert_into_your_data_path'
DATA_PATH = os.path.join(DATA_PATH, 'MIMIC')
if not os.path.exists(DATA_PATH): os.makedirs(DATA_PATH)

# Reading MIMIC dataset
datasets = RETAIN_datasets(treat, comparator, CKD_code_dict, max_visit_size)

Loading at.. /Final_treat_dx.pkl
Loading at.. /Final_comparator_dx.pkl

## Cutting length one more..
before:  treat:  275 comparator:  1067
after:  treat:  275 comparator:  1067
code_size:  2872


# [2] Setting hyperparameter

In [2]:
batch_size = 10
time_size = max_visit_size
code_size = len(CKD_code_dict)
label_size = 2
embedding_size = 128
lr_init = 0.0001
decay_step = 2000
decay_rate = 0.9
training_step = 5000
printby = 100

hidden_size_alpha = 128
hidden_size_beta = 128

# [3] Setting model

In [3]:
import tensorflow as tf

def sequence_masking(data, visit_times):
    masking = tf.tile(tf.reshape(tf.sequence_mask(visit_times, data.shape[1]), shape=[-1,data.shape[1],1]), [1,1,data.shape[2]])
    return tf.where(masking, data, tf.zeros_like(data))

inputs = tf.placeholder(tf.float32, shape=[None, time_size, code_size])
labels = tf.placeholder(tf.float32, shape=[None, time_size, label_size])
visit_times = tf.placeholder(tf.float32, shape=[None])
global_step = tf.Variable(0, trainable=False, dtype=tf.int32)
lr = tf.train.exponential_decay(lr_init, global_step, decay_step, decay_rate, staircase=True)

##embedding
W_emb = tf.Variable(tf.random_normal(shape=[code_size, embedding_size]), name='W_emb')
b_emb = tf.Variable(tf.random_normal(shape=[embedding_size]), name='b_emb')
reshaped_inputs = tf.reshape(inputs, shape=[-1, code_size])
embedded = tf.matmul(reshaped_inputs, W_emb) + b_emb
embedded = tf.reshape(embedded, shape=[-1, time_size, embedding_size])
v_ = tf.nn.tanh(embedded)

##alpha
W_alpha = tf.Variable(tf.random_normal(shape=[hidden_size_alpha, 1], name='W_alpha'))
b_alpha = tf.Variable(tf.random_normal(shape=[], name='b_alpha'))
with tf.variable_scope('alpha'):
    cells_alpha = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.LSTMCell(hidden_size_alpha)])
    outputs_alpha, states_alpha = tf.nn.dynamic_rnn(cells_alpha, v_, visit_times, dtype=tf.float32)
    reshaped_outputs_alpha = tf.reshape(outputs_alpha, shape=[-1, hidden_size_alpha])
    matmuled_alpha = tf.matmul(reshaped_outputs_alpha, W_alpha) + tf.expand_dims(b_alpha, 0)
    reshaped_matmuled_alpha = tf.reshape(matmuled_alpha, shape=[-1, time_size, 1])
    logits_alpha = tf.nn.softmax(reshaped_matmuled_alpha, 1)
    
##Beta
W_beta = tf.Variable(tf.random_normal(shape=[hidden_size_beta, code_size]))
b_beta = tf.Variable(tf.random_normal(shape=[code_size]))
with tf.variable_scope('beta'):
    cells_beta = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.LSTMCell(hidden_size_beta)])
    outputs_beta, states_beta = tf.nn.dynamic_rnn(cells_beta, v_, visit_times, dtype=tf.float32)
    reshaped_outputs_beta = tf.reshape(outputs_beta, shape=[-1, hidden_size_beta])
    matmuled_ouputs_beta = tf.matmul(reshaped_outputs_beta, W_beta) + tf.expand_dims(b_beta, 0)
    reshaped_matmuled_beta = tf.reshape(matmuled_ouputs_beta, shape=[-1, time_size, code_size])
    logits_beta = tf.nn.tanh(reshaped_matmuled_beta)
    
##unifiy
W_s = tf.Variable(tf.random_normal(shape=[code_size, label_size]))
b_s = tf.Variable(tf.random_normal(shape=[label_size]))
unified_context = logits_alpha*logits_beta*inputs
context_vec = []
for b in range(batch_size):
    pid_seq = unified_context[b]
    for t in range(time_size):
        context_vec.append(tf.reduce_sum(pid_seq[:t+1, :], axis=0))
reshaped_context_vec = tf.reshape(context_vec, shape=[-1, code_size])
mlp_context_vec = tf.matmul(reshaped_context_vec, W_s) + b_s
logits = tf.nn.softmax(tf.reshape(mlp_context_vec, shape=[-1, time_size, label_size]))
masked_logits = sequence_masking(logits, visit_times)
masked_labels = sequence_masking(labels, visit_times)

##Loss function
loss_per_times = tf.reduce_sum(masked_labels*tf.log(masked_logits+1e-10)+(1-masked_labels)*tf.log(1-masked_logits+1e-10), axis=-1)
loss_per_pid = tf.reduce_sum(loss_per_times, axis=-1)*(1/visit_times)
last_loss = -tf.reduce_sum(loss_per_pid)*(1/batch_size)
optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(last_loss, global_step=global_step)

# [4] Training & Testing

In [4]:
from tqdm import trange
with tf.Session() as sess:
    tf.train.start_queue_runners(sess=sess)
    tf.global_variables_initializer().run(session=sess)

    for step in trange(training_step):
        batch_train_inputs, batch_train_labels, batch_train_visit_times = datasets.train.next_batch(batch_size)
        batch_val_inputs, batch_val_labels, batch_val_visit_times = datasets.validation.next_batch(batch_size)
        batch_te_inputs, batch_te_labels, batch_te_visit_times = datasets.test.next_batch(batch_size)
        train_feed_dict = {inputs: batch_train_inputs, 
                           labels: batch_train_labels, 
                           visit_times: batch_train_visit_times}
        validation_feed_dict = {inputs: batch_val_inputs, 
                                labels: batch_val_labels, 
                                visit_times: batch_val_visit_times}
        test_feed_dict = {inputs: batch_te_inputs, 
                          labels: batch_te_labels, 
                          visit_times: batch_te_visit_times}
        _, g_step, train_loss, train_lr_ = sess.run([optimizer, global_step, last_loss, lr], feed_dict=train_feed_dict)
        val_loss, val_lr_ = sess.run([last_loss, lr], feed_dict=validation_feed_dict)
        
        if step % printby == 0:
            print('step: {} \tg_step: {} \tloss: {:.8f}||{:.8f} \tlr: {:.8f}||{:.8f}'.format(step, g_step, train_loss, val_loss, train_lr_, val_lr_))
        
        if step == training_step-1:
            te_loss = sess.run(last_loss, feed_dict=test_feed_dict)
            print('All Done! test_loss is {}'.format(te_loss))

  0%|          | 2/5000 [00:03<3:31:32,  2.54s/it]

step: 0 	g_step: 1 	loss: 1.77771223||1.76532292 	lr: 0.00010000||0.00010000


  2%|▏         | 101/5000 [00:25<17:01,  4.80it/s]

step: 100 	g_step: 101 	loss: 1.36491549||1.72819364 	lr: 0.00010000||0.00010000


  4%|▍         | 201/5000 [00:44<15:15,  5.24it/s]

step: 200 	g_step: 201 	loss: 1.85386586||1.65774083 	lr: 0.00010000||0.00010000


  6%|▌         | 302/5000 [01:05<13:01,  6.01it/s]

step: 300 	g_step: 301 	loss: 1.31635725||1.54219186 	lr: 0.00010000||0.00010000


  8%|▊         | 402/5000 [01:24<13:46,  5.56it/s]

step: 400 	g_step: 401 	loss: 0.59791499||1.75903285 	lr: 0.00010000||0.00010000


 10%|█         | 502/5000 [01:41<12:38,  5.93it/s]

step: 500 	g_step: 501 	loss: 0.69227809||3.00949264 	lr: 0.00010000||0.00010000


 12%|█▏        | 601/5000 [02:01<16:15,  4.51it/s]

step: 600 	g_step: 601 	loss: 0.94558281||1.21598983 	lr: 0.00010000||0.00010000


 14%|█▍        | 702/5000 [02:20<13:44,  5.21it/s]

step: 700 	g_step: 701 	loss: 0.53898090||2.34240532 	lr: 0.00010000||0.00010000


 16%|█▌        | 802/5000 [02:40<14:09,  4.94it/s]

step: 800 	g_step: 801 	loss: 0.63876337||1.79580045 	lr: 0.00010000||0.00010000


 18%|█▊        | 902/5000 [03:00<14:24,  4.74it/s]

step: 900 	g_step: 901 	loss: 0.22118559||2.39732504 	lr: 0.00010000||0.00010000


 20%|██        | 1002/5000 [03:17<11:16,  5.91it/s]

step: 1000 	g_step: 1001 	loss: 0.12590156||2.04432082 	lr: 0.00010000||0.00010000


 22%|██▏       | 1102/5000 [03:35<11:19,  5.73it/s]

step: 1100 	g_step: 1101 	loss: 0.48554587||2.01463819 	lr: 0.00010000||0.00010000


 24%|██▍       | 1202/5000 [03:52<10:50,  5.84it/s]

step: 1200 	g_step: 1201 	loss: 0.31593037||3.11750269 	lr: 0.00010000||0.00010000


 26%|██▌       | 1302/5000 [04:09<10:13,  6.03it/s]

step: 1300 	g_step: 1301 	loss: 0.13908041||3.32683682 	lr: 0.00010000||0.00010000


 28%|██▊       | 1402/5000 [04:27<10:32,  5.69it/s]

step: 1400 	g_step: 1401 	loss: 0.30810985||1.73978806 	lr: 0.00010000||0.00010000


 30%|███       | 1502/5000 [04:45<10:28,  5.57it/s]

step: 1500 	g_step: 1501 	loss: 0.20335813||3.86767006 	lr: 0.00010000||0.00010000


 32%|███▏      | 1602/5000 [05:02<09:39,  5.86it/s]

step: 1600 	g_step: 1601 	loss: 0.15757908||2.33901954 	lr: 0.00010000||0.00010000


 34%|███▍      | 1702/5000 [05:19<09:43,  5.66it/s]

step: 1700 	g_step: 1701 	loss: 0.07970976||2.02240157 	lr: 0.00010000||0.00010000


 36%|███▌      | 1802/5000 [05:37<09:19,  5.71it/s]

step: 1800 	g_step: 1801 	loss: 0.32849309||2.30527020 	lr: 0.00010000||0.00010000


 38%|███▊      | 1902/5000 [05:54<08:48,  5.86it/s]

step: 1900 	g_step: 1901 	loss: 0.04720712||1.66699970 	lr: 0.00010000||0.00010000


 40%|████      | 2002/5000 [06:12<09:54,  5.04it/s]

step: 2000 	g_step: 2001 	loss: 0.04436038||1.74172020 	lr: 0.00009000||0.00009000


 42%|████▏     | 2102/5000 [06:33<09:06,  5.30it/s]

step: 2100 	g_step: 2101 	loss: 0.05007234||2.74863601 	lr: 0.00009000||0.00009000


 44%|████▍     | 2201/5000 [06:55<09:21,  4.99it/s]

step: 2200 	g_step: 2201 	loss: 0.07125783||0.78394407 	lr: 0.00009000||0.00009000


 46%|████▌     | 2301/5000 [07:14<09:10,  4.90it/s]

step: 2300 	g_step: 2301 	loss: 0.05568868||4.71728277 	lr: 0.00009000||0.00009000


 48%|████▊     | 2401/5000 [07:35<08:46,  4.94it/s]

step: 2400 	g_step: 2401 	loss: 0.06120616||4.27806234 	lr: 0.00009000||0.00009000


 50%|█████     | 2501/5000 [07:56<07:32,  5.53it/s]

step: 2500 	g_step: 2501 	loss: 0.22462535||1.62167060 	lr: 0.00009000||0.00009000


 52%|█████▏    | 2602/5000 [08:17<08:30,  4.70it/s]

step: 2600 	g_step: 2601 	loss: 0.21704070||3.68326306 	lr: 0.00009000||0.00009000


 54%|█████▍    | 2701/5000 [08:40<08:59,  4.26it/s]

step: 2700 	g_step: 2701 	loss: 0.08532182||3.27630782 	lr: 0.00009000||0.00009000


 56%|█████▌    | 2801/5000 [09:01<07:57,  4.61it/s]

step: 2800 	g_step: 2801 	loss: 0.05612302||4.65197897 	lr: 0.00009000||0.00009000


 58%|█████▊    | 2902/5000 [09:21<06:14,  5.60it/s]

step: 2900 	g_step: 2901 	loss: 0.34948453||3.44216323 	lr: 0.00009000||0.00009000


 60%|██████    | 3002/5000 [09:39<05:43,  5.81it/s]

step: 3000 	g_step: 3001 	loss: 0.01574058||4.29314375 	lr: 0.00009000||0.00009000


 62%|██████▏   | 3102/5000 [09:57<05:45,  5.49it/s]

step: 3100 	g_step: 3101 	loss: 0.01038979||2.22397399 	lr: 0.00009000||0.00009000


 64%|██████▍   | 3202/5000 [10:15<05:19,  5.63it/s]

step: 3200 	g_step: 3201 	loss: 0.04243891||3.74551129 	lr: 0.00009000||0.00009000


 66%|██████▌   | 3301/5000 [10:33<06:00,  4.71it/s]

step: 3300 	g_step: 3301 	loss: 0.00919096||3.54667592 	lr: 0.00009000||0.00009000


 68%|██████▊   | 3401/5000 [10:52<04:51,  5.49it/s]

step: 3400 	g_step: 3401 	loss: 0.04248129||2.14580894 	lr: 0.00009000||0.00009000


 70%|███████   | 3502/5000 [11:11<04:16,  5.83it/s]

step: 3500 	g_step: 3501 	loss: 0.01489031||3.61428308 	lr: 0.00009000||0.00009000


 72%|███████▏  | 3602/5000 [11:29<04:31,  5.16it/s]

step: 3600 	g_step: 3601 	loss: 0.05693358||2.38219357 	lr: 0.00009000||0.00009000


 74%|███████▍  | 3701/5000 [11:47<03:56,  5.49it/s]

step: 3700 	g_step: 3701 	loss: 0.02754932||4.60098696 	lr: 0.00009000||0.00009000


 76%|███████▌  | 3801/5000 [12:06<03:58,  5.04it/s]

step: 3800 	g_step: 3801 	loss: 0.00848300||4.76792526 	lr: 0.00009000||0.00009000


 78%|███████▊  | 3902/5000 [12:27<03:31,  5.19it/s]

step: 3900 	g_step: 3901 	loss: 0.02832838||0.56622380 	lr: 0.00009000||0.00009000


 80%|████████  | 4002/5000 [12:47<03:05,  5.37it/s]

step: 4000 	g_step: 4001 	loss: 0.01175793||4.08318329 	lr: 0.00008100||0.00008100


 82%|████████▏ | 4102/5000 [13:05<02:46,  5.39it/s]

step: 4100 	g_step: 4101 	loss: 0.00720588||2.59575009 	lr: 0.00008100||0.00008100


 84%|████████▍ | 4202/5000 [13:24<02:28,  5.39it/s]

step: 4200 	g_step: 4201 	loss: 0.01241960||4.27499008 	lr: 0.00008100||0.00008100


 86%|████████▌ | 4302/5000 [13:41<01:56,  5.98it/s]

step: 4300 	g_step: 4301 	loss: 0.18100224||2.99268866 	lr: 0.00008100||0.00008100


 88%|████████▊ | 4402/5000 [14:02<01:50,  5.39it/s]

step: 4400 	g_step: 4401 	loss: 0.03394793||2.76620436 	lr: 0.00008100||0.00008100


 90%|█████████ | 4502/5000 [14:22<01:25,  5.79it/s]

step: 4500 	g_step: 4501 	loss: 0.00668651||2.98700714 	lr: 0.00008100||0.00008100


 92%|█████████▏| 4601/5000 [14:42<01:26,  4.63it/s]

step: 4600 	g_step: 4601 	loss: 0.01399853||1.99717009 	lr: 0.00008100||0.00008100


 94%|█████████▍| 4701/5000 [15:02<00:56,  5.25it/s]

step: 4700 	g_step: 4701 	loss: 0.04530440||1.86162817 	lr: 0.00008100||0.00008100


 96%|█████████▌| 4802/5000 [15:19<00:33,  5.91it/s]

step: 4800 	g_step: 4801 	loss: 0.01429659||0.76042825 	lr: 0.00008100||0.00008100


 98%|█████████▊| 4902/5000 [15:36<00:15,  6.13it/s]

step: 4900 	g_step: 4901 	loss: 0.04870249||3.22787261 	lr: 0.00008100||0.00008100


100%|██████████| 5000/5000 [15:52<00:00,  2.99it/s]

All Done! test_loss is 2.0485293865203857





# [5] Training Reversed dataset

In [6]:
from Dataset import reverse_seq

rvrsed_treat = reverse_seq(treat)
rvrsed_comparator = reverse_seq(comparator)

rvrsed_datasets = RETAIN_datasets(rvrsed_treat, rvrsed_comparator, CKD_code_dict, max_visit_size)

In [7]:
from tqdm import trange
with tf.Session() as sess:
    tf.train.start_queue_runners(sess=sess)
    tf.global_variables_initializer().run(session=sess)

    for step in trange(training_step):
        batch_train_inputs, batch_train_labels, batch_train_visit_times = rvrsed_datasets.train.next_batch(batch_size)
        batch_val_inputs, batch_val_labels, batch_val_visit_times = rvrsed_datasets.validation.next_batch(batch_size)
        batch_te_inputs, batch_te_labels, batch_te_visit_times = rvrsed_datasets.test.next_batch(batch_size)
        train_feed_dict = {inputs: batch_train_inputs, 
                           labels: batch_train_labels, 
                           visit_times: batch_train_visit_times}
        validation_feed_dict = {inputs: batch_val_inputs, 
                                labels: batch_val_labels, 
                                visit_times: batch_val_visit_times}
        test_feed_dict = {inputs: batch_te_inputs, 
                          labels: batch_te_labels, 
                          visit_times: batch_te_visit_times}
        _, g_step, train_loss, train_lr_ = sess.run([optimizer, global_step, last_loss, lr], feed_dict=train_feed_dict)
        val_loss, val_lr_ = sess.run([last_loss, lr], feed_dict=validation_feed_dict)
        
        if step % printby == 0:
            print('step: {} \tg_step: {} \tloss: {:.8f}||{:.8f} \tlr: {:.8f}||{:.8f}'.format(step, g_step, train_loss, val_loss, train_lr_, val_lr_))
        
        if step == training_step-1:
            te_loss = sess.run(last_loss, feed_dict=test_feed_dict)
            print('All Done! test_loss is {}'.format(te_loss))

  0%|          | 2/5000 [00:03<3:00:17,  2.16s/it]

step: 0 	g_step: 1 	loss: 1.89690912||1.69452178 	lr: 0.00010000||0.00010000


  2%|▏         | 102/5000 [00:20<14:16,  5.72it/s]

step: 100 	g_step: 101 	loss: 1.48023438||1.74885869 	lr: 0.00010000||0.00010000


  4%|▍         | 202/5000 [00:40<15:56,  5.01it/s]

step: 200 	g_step: 201 	loss: 1.44120491||1.43612981 	lr: 0.00010000||0.00010000


  6%|▌         | 302/5000 [00:59<12:57,  6.04it/s]

step: 300 	g_step: 301 	loss: 1.41969240||1.63030529 	lr: 0.00010000||0.00010000


  8%|▊         | 402/5000 [01:16<12:45,  6.01it/s]

step: 400 	g_step: 401 	loss: 1.29558361||1.64614201 	lr: 0.00010000||0.00010000


 10%|█         | 502/5000 [01:32<12:09,  6.17it/s]

step: 500 	g_step: 501 	loss: 0.69793469||2.37376475 	lr: 0.00010000||0.00010000


 12%|█▏        | 602/5000 [01:49<11:50,  6.19it/s]

step: 600 	g_step: 601 	loss: 0.92424345||1.82619727 	lr: 0.00010000||0.00010000


 14%|█▍        | 702/5000 [02:05<11:49,  6.06it/s]

step: 700 	g_step: 701 	loss: 0.80797809||2.69602275 	lr: 0.00010000||0.00010000


 16%|█▌        | 802/5000 [02:22<11:17,  6.19it/s]

step: 800 	g_step: 801 	loss: 0.47348318||1.87765431 	lr: 0.00010000||0.00010000


 18%|█▊        | 902/5000 [02:38<11:41,  5.84it/s]

step: 900 	g_step: 901 	loss: 0.28019845||2.52113986 	lr: 0.00010000||0.00010000


 20%|██        | 1002/5000 [02:55<11:14,  5.93it/s]

step: 1000 	g_step: 1001 	loss: 0.76690930||2.72170901 	lr: 0.00010000||0.00010000


 22%|██▏       | 1102/5000 [03:12<10:56,  5.94it/s]

step: 1100 	g_step: 1101 	loss: 0.71308535||1.78225064 	lr: 0.00010000||0.00010000


 24%|██▍       | 1202/5000 [03:28<10:28,  6.04it/s]

step: 1200 	g_step: 1201 	loss: 0.50230998||2.00043058 	lr: 0.00010000||0.00010000


 26%|██▌       | 1302/5000 [03:45<10:07,  6.08it/s]

step: 1300 	g_step: 1301 	loss: 0.28285471||3.84568262 	lr: 0.00010000||0.00010000


 28%|██▊       | 1402/5000 [04:01<09:48,  6.11it/s]

step: 1400 	g_step: 1401 	loss: 0.14326595||1.86456335 	lr: 0.00010000||0.00010000


 30%|███       | 1502/5000 [04:18<09:30,  6.13it/s]

step: 1500 	g_step: 1501 	loss: 0.26392862||4.41303587 	lr: 0.00010000||0.00010000


 32%|███▏      | 1602/5000 [04:35<09:10,  6.18it/s]

step: 1600 	g_step: 1601 	loss: 0.34400505||2.59716797 	lr: 0.00010000||0.00010000


 34%|███▍      | 1702/5000 [04:51<08:53,  6.18it/s]

step: 1700 	g_step: 1701 	loss: 0.16148101||2.40023828 	lr: 0.00010000||0.00010000


 36%|███▌      | 1802/5000 [05:08<08:41,  6.14it/s]

step: 1800 	g_step: 1801 	loss: 0.06433099||3.94969678 	lr: 0.00010000||0.00010000


 38%|███▊      | 1902/5000 [05:24<08:45,  5.90it/s]

step: 1900 	g_step: 1901 	loss: 0.35265899||3.18422580 	lr: 0.00010000||0.00010000


 40%|████      | 2002/5000 [05:41<08:31,  5.86it/s]

step: 2000 	g_step: 2001 	loss: 0.14776574||2.83403587 	lr: 0.00009000||0.00009000


 42%|████▏     | 2102/5000 [05:57<08:05,  5.97it/s]

step: 2100 	g_step: 2101 	loss: 0.42529425||3.05529976 	lr: 0.00009000||0.00009000


 44%|████▍     | 2202/5000 [06:14<07:46,  6.00it/s]

step: 2200 	g_step: 2201 	loss: 0.04028940||1.93694723 	lr: 0.00009000||0.00009000


 46%|████▌     | 2302/5000 [06:30<07:23,  6.08it/s]

step: 2300 	g_step: 2301 	loss: 0.09160873||2.77542186 	lr: 0.00009000||0.00009000


 48%|████▊     | 2402/5000 [06:47<07:03,  6.13it/s]

step: 2400 	g_step: 2401 	loss: 0.01894629||4.59111404 	lr: 0.00009000||0.00009000


 50%|█████     | 2502/5000 [07:03<06:47,  6.13it/s]

step: 2500 	g_step: 2501 	loss: 0.17513354||2.96664429 	lr: 0.00009000||0.00009000


 52%|█████▏    | 2602/5000 [07:20<06:29,  6.16it/s]

step: 2600 	g_step: 2601 	loss: 0.10111352||1.98912680 	lr: 0.00009000||0.00009000


 54%|█████▍    | 2702/5000 [07:37<06:28,  5.91it/s]

step: 2700 	g_step: 2701 	loss: 0.01244320||3.23239398 	lr: 0.00009000||0.00009000


 56%|█████▌    | 2802/5000 [07:53<06:05,  6.01it/s]

step: 2800 	g_step: 2801 	loss: 0.11870161||3.48469782 	lr: 0.00009000||0.00009000


 58%|█████▊    | 2902/5000 [08:11<05:57,  5.87it/s]

step: 2900 	g_step: 2901 	loss: 0.01832391||3.28014040 	lr: 0.00009000||0.00009000


 60%|██████    | 3002/5000 [08:33<06:42,  4.96it/s]

step: 3000 	g_step: 3001 	loss: 0.06413088||7.11774683 	lr: 0.00009000||0.00009000


 62%|██████▏   | 3102/5000 [08:52<07:11,  4.40it/s]

step: 3100 	g_step: 3101 	loss: 0.05377886||2.98013783 	lr: 0.00009000||0.00009000


 64%|██████▍   | 3201/5000 [09:11<06:26,  4.66it/s]

step: 3200 	g_step: 3201 	loss: 0.04689698||2.69291568 	lr: 0.00009000||0.00009000


 66%|██████▌   | 3302/5000 [09:30<04:43,  6.00it/s]

step: 3300 	g_step: 3301 	loss: 0.02183354||1.55048549 	lr: 0.00009000||0.00009000


 68%|██████▊   | 3402/5000 [09:47<04:18,  6.18it/s]

step: 3400 	g_step: 3401 	loss: 0.10570898||2.66172671 	lr: 0.00009000||0.00009000


 70%|███████   | 3502/5000 [10:04<04:04,  6.13it/s]

step: 3500 	g_step: 3501 	loss: 0.18161990||5.65461302 	lr: 0.00009000||0.00009000


 72%|███████▏  | 3602/5000 [10:20<03:51,  6.04it/s]

step: 3600 	g_step: 3601 	loss: 0.02309214||1.12742496 	lr: 0.00009000||0.00009000


 74%|███████▍  | 3702/5000 [10:37<03:39,  5.90it/s]

step: 3700 	g_step: 3701 	loss: 0.07360192||2.63685465 	lr: 0.00009000||0.00009000


 76%|███████▌  | 3801/5000 [10:53<03:52,  5.17it/s]

step: 3800 	g_step: 3801 	loss: 0.03743961||1.49955142 	lr: 0.00009000||0.00009000


 78%|███████▊  | 3902/5000 [11:12<03:12,  5.70it/s]

step: 3900 	g_step: 3901 	loss: 0.01164702||4.25912046 	lr: 0.00009000||0.00009000


 80%|████████  | 4002/5000 [11:30<03:16,  5.08it/s]

step: 4000 	g_step: 4001 	loss: 0.04629619||2.58767176 	lr: 0.00008100||0.00008100


 82%|████████▏ | 4101/5000 [11:50<03:27,  4.33it/s]

step: 4100 	g_step: 4101 	loss: 0.01222607||2.92262340 	lr: 0.00008100||0.00008100


 84%|████████▍ | 4201/5000 [12:10<02:25,  5.50it/s]

step: 4200 	g_step: 4201 	loss: 0.01224753||5.73890638 	lr: 0.00008100||0.00008100


 86%|████████▌ | 4302/5000 [12:30<01:57,  5.93it/s]

step: 4300 	g_step: 4301 	loss: 0.02066518||3.18055010 	lr: 0.00008100||0.00008100


 88%|████████▊ | 4402/5000 [12:48<01:36,  6.17it/s]

step: 4400 	g_step: 4401 	loss: 0.00932878||3.07053518 	lr: 0.00008100||0.00008100


 90%|█████████ | 4502/5000 [13:06<01:22,  6.04it/s]

step: 4500 	g_step: 4501 	loss: 0.01239224||3.75556636 	lr: 0.00008100||0.00008100


 92%|█████████▏| 4602/5000 [13:23<01:04,  6.17it/s]

step: 4600 	g_step: 4601 	loss: 0.01578561||1.28752124 	lr: 0.00008100||0.00008100


 94%|█████████▍| 4702/5000 [13:40<00:48,  6.10it/s]

step: 4700 	g_step: 4701 	loss: 0.11321343||2.44210649 	lr: 0.00008100||0.00008100


 96%|█████████▌| 4802/5000 [13:56<00:31,  6.23it/s]

step: 4800 	g_step: 4801 	loss: 0.01473635||4.33526802 	lr: 0.00008100||0.00008100


 98%|█████████▊| 4902/5000 [14:13<00:16,  5.88it/s]

step: 4900 	g_step: 4901 	loss: 0.10136785||2.70155931 	lr: 0.00008100||0.00008100


100%|██████████| 5000/5000 [14:29<00:00,  2.95it/s]

All Done! test_loss is 2.3594472408294678



