[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kamalkraj/minGPT-TF/blob/master/play_math.ipynb)

In [42]:
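# NOTE: this notebook was described as a 3-digit example using all GPUs for 100 iterations,
# but the cells below actually configure ndigit=2, max_epochs=50 and device='GPU:1'.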

In [43]:
! pip install fastprogress==0.2.3

[0m--- Logging error ---
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/pip/_internal/utils/logging.py", line 177, in emit
    self.console.print(renderable, overflow="ignore", crop=False, style=style)
  File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/rich/console.py", line 1673, in print
    extend(render(renderable, render_options))
  File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/rich/console.py", line 1305, in render
    for render_output in iter_render:
  File "/usr/local/lib/python3.8/dist-packages/pip/_internal/utils/logging.py", line 134, in __rich_console__
    for line in lines:
  File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/rich/segment.py", line 249, in split_lines
    for segment in segments:
  File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/rich/console.py", line 1283, in render
    renderable = rich_cast(renderable)
  File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/rich/protocol.py", lin

In [44]:
import math
import numpy as np
import tensorflow as tf
from mingpt.model import GPT, GPTConfig

In [45]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

In [46]:
class AdditionDataset():
    """
    Returns addition problems of up to some number of digits in the inputs. Recall
    that all GPT cares about are sequences of integers, and completing them according to
    patterns in the data. Therefore, we have to somehow encode addition problems
    as a sequence of integers.

    The sum of two n-digit numbers gives a third up to (n+1)-digit number. So our
    encoding will simply be the n-digit first number, n-digit second number,
    and (n+1)-digit result, all simply concatenated together. Because each addition
    problem is so structured, there is no need to bother the model with encoding
    +, =, or other tokens. Each possible sequence has the same length, and simply
    contains the raw digits of the addition problem.

    As a few examples, the 2-digit problems:
    - 85 + 50 = 135 becomes the sequence [8, 5, 5, 0, 1, 3, 5]
    - 6 + 39 = 45 becomes the sequence [0, 6, 3, 9, 0, 4, 5]
    etc.

    We will also only train GPT on the final (n+1)-digits because the first
    two n-digits are always assumed to be given. So when we give GPT an exam later,
    we will e.g. feed it the sequence [0, 6, 3, 9], which encodes that we'd like
    to add 6 + 39, and hope that the model completes the integer sequence with [0, 4, 5]
    in 3 sequential steps. Target positions that are not trained on are set to -1.

    The object is both iterable and callable (``__call__`` is ``__iter__``), so it
    can be handed directly to ``tf.data.Dataset.from_generator``. Each element is
    an ``(x, y)`` pair of int32 tensors of length ``block_size``.

    fun exercise: does it help if the result is asked to be produced in reverse order?

    Args:
        ndigit: number of digits in each operand.
        split: 'test' selects the held-out problems; any other value selects the
            training problems.
    """

    def __init__(self, ndigit, split):
        self.split = split  # 'test' -> held-out problems, anything else -> training problems
        self.ndigit = ndigit
        self.vocab_size = 10  # 10 possible digits 0..9
        # +1 due to potential carry overflow, but then -1 because very last digit doesn't plug back
        self.block_size = ndigit + ndigit + ndigit + 1 - 1

        # split up all addition problems into either training data or test data
        num = (10**self.ndigit)**2  # total number of possible (a, b) combinations
        r = np.random.RandomState(1337)  # fixed seed: both splits see the same permutation
        perm = r.permutation(num)
        num_test = min(int(num*0.2), 1000)  # 20% of the whole dataset, or only up to 1000
        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]

    def __len__(self):
        """Return the number of addition problems in this split."""
        return self.ixes.size

    def __iter__(self):
        """Yield one ``(x, y)`` pair of int32 tensors per problem in this split.

        ``x`` is the full digit sequence without its last token. ``y`` is the same
        sequence shifted left by one (next-token targets), with its first
        ``2*ndigit - 1`` entries set to -1. As a result, only the ``ndigit + 1``
        digits of the sum are supervised.
        """
        # Loop invariants, computed once per pass instead of once per example.
        nd = 10**self.ndigit
        # e.g. for ndigit=2: 03+25=28 becomes "0325028"
        template = f'%0{self.ndigit}d%0{self.ndigit}d%0{self.ndigit+1}d'
        n_masked = self.ndigit*2 - 1

        for idx in self.ixes:
            # a problem index encodes the operands as idx = a * nd + b
            a, b = divmod(idx, nd)
            c = a + b
            render = template % (a, b, c)
            dix = [int(s) for s in render]  # convert each character to its token index

            # x will be input to GPT and y will be the associated expected outputs
            x = dix[:-1]
            y = dix[1:]  # predict the next token in the sequence
            # We only train on the output locations. y[i] is the target for x[i], and
            # the first 2*ndigit-1 targets are still operand digits, so mark them -1.
            # NOTE(review): assumes the Trainer's loss ignores targets equal to -1; confirm.
            y[:n_masked] = [-1] * n_masked

            yield tf.convert_to_tensor(x, dtype=tf.int32), tf.convert_to_tensor(y, dtype=tf.int32)

    # tf.data.Dataset.from_generator takes a callable that returns an iterator.
    __call__ = __iter__

In [47]:
# create a dataset for e.g. 2-digit addition
# (ndigit is also read as a global by give_exam when scoring the model)
ndigit = 2
# Both splits are carved from the same seeded permutation inside AdditionDataset,
# so the (up to 1000) test problems never appear in the training split.
train_dataset_gen = AdditionDataset(ndigit=ndigit, split='train')
test_dataset_gen = AdditionDataset(ndigit=ndigit, split='test')

In [48]:
# Wrap the Python generators in tf.data pipelines. `output_signature` (TF >= 2.4)
# replaces the deprecated positional `output_types`, and it also gives every element
# a static shape: both x and y hold exactly block_size int32 tokens.
train_dataset = tf.data.Dataset.from_generator(
    train_dataset_gen,
    output_signature=(tf.TensorSpec(shape=(train_dataset_gen.block_size,), dtype=tf.int32),
                      tf.TensorSpec(shape=(train_dataset_gen.block_size,), dtype=tf.int32)))
test_dataset = tf.data.Dataset.from_generator(
    test_dataset_gen,
    output_signature=(tf.TensorSpec(shape=(test_dataset_gen.block_size,), dtype=tf.int32),
                      tf.TensorSpec(shape=(test_dataset_gen.block_size,), dtype=tf.int32)))

In [49]:
# initialize a baby GPT model configuration: 2 layers, 4 heads, 128-dim embeddings,
# with vocab size and block (context) size taken from the training dataset
mconf = GPTConfig(train_dataset_gen.vocab_size, train_dataset_gen.block_size, 
                  n_layer=2, n_head=4, n_embd=128)
# The model is not instantiated here: the GPT class and mconf are handed to Trainer instead.

In [50]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
# Single source of truth for the epoch count: it drives both the training length
# and the learning-rate decay horizon (final_tokens), which must stay in sync.
max_epochs = 50
tconf = TrainerConfig(max_epochs=max_epochs, batch_size=1024, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=1024,
                      # each example has ndigit+1 unmasked target tokens in y; presumably the
                      # Trainer counts only those toward final_tokens -- confirm against Trainer
                      final_tokens=max_epochs*len(train_dataset_gen)*(ndigit+1),
                      num_workers=4)
trainer = Trainer(GPT, mconf, train_dataset, len(train_dataset_gen), test_dataset, len(test_dataset_gen), tconf, device='GPU:1')

config.vocab_size=10,config.n_embd=128
config.block_size=6,config.n_embd=128
config.embd_pdrop=0.1
config.n_embd=128,config.n_head=4,config.attn_pdrop=0.1,config.resid_pdrop=0.1,config.n_layer=2
EncoderLayer d_model=128,num_heads=4
EncoderLayer d_model=128,num_heads=4


In [51]:
#train the first time and save checkpoints
trainer.train()

epoch 1: train loss 3.92096. lr 5.994512e-04
epoch 1: test loss 2.04901.
The gpt train model weight is saved to checkpoints:./checkpoings/ckpt_math_3dig/minigpt.ckpt


In [52]:
# Print the trained model's layer summary (shown in the output below).
trainer.display_model()

Model: "gpt_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_3 (Embedding)     multiple                  1280      
                                                                 
 dropout_21 (Dropout)        multiple                  0         
                                                                 
 encoder_layer_6 (EncoderLay  multiple                 198272    
 er)                                                             
                                                                 
 encoder_layer_7 (EncoderLay  multiple                 198272    
 er)                                                             
                                                                 
 layer_normalization_19 (Lay  multiple                 256       
 erNormalization)                                                
                                                             

In [53]:
from mingpt.utils import sample

In [54]:
def give_exam(dataset, batch_size=32, max_batches=-1, printResult=False):
    """Score the trained model on addition problems and print the result.

    For every example, the two operands (the first ``2*ndigit`` digits of x) are
    fed to ``sample`` together with ``trainer.model``. The ``ndigit+1`` tokens the
    model appends are decoded into an integer and compared with the true sum.

    Relies on the notebook globals ``trainer``, ``ndigit`` and ``sample``.

    Args:
        dataset: unbatched ``tf.data.Dataset`` of ``(x, y)`` pairs as produced
            from ``AdditionDataset``; only ``x`` is used.
        batch_size: number of problems sampled per call to ``sample``.
        max_batches: stop after this many batches; a negative value means no
            limit. The check runs after each batch, so 0 still evaluates one batch.
        printResult: if True, also print correctly answered problems. Wrong
            answers are always printed.
    """
    results = []

    # Place-value weights, most significant first: [[10**ndigit, ..., 10, 1]].
    # Operands use the last ndigit weights; the (ndigit+1)-digit answer uses all of them.
    factors = tf.convert_to_tensor([[10**i for i in range(ndigit+1)][::-1]])

    loader = dataset.batch(batch_size)
    for b, (x, y) in enumerate(loader):
        # only the two operands are given to the model; y is not needed here
        d1d2 = x[:, :ndigit*2]
        # let the model append ndigit+1 tokens: its predicted sum
        d1d2d3 = sample(trainer.model, d1d2, ndigit+1)
        d3 = d1d2d3[:, -(ndigit+1):]

        # Decode the integers from individual digits. Converting to NumPy once per
        # batch avoids issuing an eager TF op for every element in the loop below.
        d1i = tf.reduce_sum(d1d2[:, :ndigit] * factors[:, 1:], axis=1).numpy()
        d2i = tf.reduce_sum(d1d2[:, ndigit:ndigit*2] * factors[:, 1:], axis=1).numpy()
        d3i_pred = tf.reduce_sum(d3 * factors, axis=1).numpy()

        d3i_gt = d1i + d2i
        correct = (d3i_pred == d3i_gt) # Software 1.0 vs. Software 2.0 fight RIGHT on this line, lol
        for op1, op2, pred, gt, ok in zip(d1i, d2i, d3i_pred, d3i_gt, correct):
            results.append(int(ok))
            judge = 'YEP!!!' if ok else 'NOPE'
            if not ok:
                print("GPT claims that %03d + %03d = %03d (gt is %03d; %s)========"
                      % (op1, op2, pred, gt, judge))
            elif printResult:
                print("GPT claims that %03d + %03d = %03d (gt is %03d; %s)"
                      % (op1, op2, pred, gt, judge))

        if max_batches >= 0 and b+1 >= max_batches:
            break

    # Guard the empty case: np.mean([]) would yield NaN plus a RuntimeWarning.
    n_correct = int(np.sum(results))
    accuracy = 100*np.mean(results) if results else 0.0
    print("final score: %d/%d = %.2f%% correct" % (n_correct, len(results), accuracy))

In [55]:
#if train already and load the weights from checkpoint
#trainer.load_checkpoints()

In [56]:
# Score the model on the whole training split (max_batches=-1); with printResult=False only wrong answers are printed.
give_exam(train_dataset,batch_size=128,max_batches=-1,printResult=False)







































































final score: 2430/9000 = 27.00% correct


In [57]:
# Score the model on the held-out test split, printing every problem (correct ones included).
give_exam(test_dataset,batch_size=128,max_batches=-1,printResult=True)

GPT claims that 099 + 027 = 126 (gt is 126; YEP!!!)
GPT claims that 032 + 045 = 077 (gt is 077; YEP!!!)
GPT claims that 015 + 089 = 104 (gt is 104; YEP!!!)
GPT claims that 047 + 046 = 093 (gt is 093; YEP!!!)
GPT claims that 057 + 039 = 096 (gt is 096; YEP!!!)
GPT claims that 097 + 049 = 146 (gt is 146; YEP!!!)
GPT claims that 004 + 068 = 072 (gt is 072; YEP!!!)
GPT claims that 044 + 018 = 062 (gt is 062; YEP!!!)
GPT claims that 070 + 030 = 100 (gt is 100; YEP!!!)
GPT claims that 064 + 012 = 076 (gt is 076; YEP!!!)
GPT claims that 037 + 027 = 064 (gt is 064; YEP!!!)
GPT claims that 048 + 014 = 062 (gt is 062; YEP!!!)
GPT claims that 060 + 059 = 119 (gt is 119; YEP!!!)
GPT claims that 063 + 092 = 155 (gt is 155; YEP!!!)
GPT claims that 037 + 029 = 066 (gt is 066; YEP!!!)
GPT claims that 042 + 034 = 076 (gt is 076; YEP!!!)
GPT claims that 089 + 027 = 116 (gt is 116; YEP!!!)
GPT claims that 096 + 050 = 146 (gt is 146; YEP!!!)
GPT claims that 053 + 076 = 129 (gt is 129; YEP!!!)
GPT claims t

GPT claims that 033 + 036 = 069 (gt is 069; YEP!!!)
GPT claims that 022 + 083 = 105 (gt is 105; YEP!!!)
GPT claims that 049 + 075 = 124 (gt is 124; YEP!!!)
GPT claims that 051 + 053 = 104 (gt is 104; YEP!!!)
GPT claims that 090 + 072 = 162 (gt is 162; YEP!!!)
GPT claims that 057 + 076 = 133 (gt is 133; YEP!!!)
GPT claims that 099 + 067 = 166 (gt is 166; YEP!!!)
GPT claims that 057 + 079 = 136 (gt is 136; YEP!!!)
GPT claims that 099 + 061 = 160 (gt is 160; YEP!!!)
GPT claims that 094 + 072 = 166 (gt is 166; YEP!!!)
GPT claims that 021 + 028 = 049 (gt is 049; YEP!!!)
GPT claims that 093 + 073 = 166 (gt is 166; YEP!!!)
GPT claims that 092 + 075 = 167 (gt is 167; YEP!!!)
GPT claims that 052 + 095 = 147 (gt is 147; YEP!!!)
GPT claims that 065 + 067 = 132 (gt is 132; YEP!!!)
GPT claims that 073 + 028 = 101 (gt is 101; YEP!!!)
GPT claims that 017 + 050 = 067 (gt is 067; YEP!!!)
GPT claims that 089 + 048 = 137 (gt is 137; YEP!!!)
GPT claims that 079 + 067 = 146 (gt is 146; YEP!!!)
GPT claims t

GPT claims that 065 + 042 = 107 (gt is 107; YEP!!!)
GPT claims that 073 + 034 = 107 (gt is 107; YEP!!!)
GPT claims that 073 + 059 = 132 (gt is 132; YEP!!!)
GPT claims that 014 + 049 = 063 (gt is 063; YEP!!!)
GPT claims that 015 + 030 = 045 (gt is 045; YEP!!!)
GPT claims that 008 + 064 = 072 (gt is 072; YEP!!!)
GPT claims that 002 + 040 = 042 (gt is 042; YEP!!!)
GPT claims that 060 + 044 = 104 (gt is 104; YEP!!!)
GPT claims that 036 + 081 = 117 (gt is 117; YEP!!!)
GPT claims that 082 + 082 = 164 (gt is 164; YEP!!!)
GPT claims that 060 + 099 = 159 (gt is 159; YEP!!!)
GPT claims that 004 + 058 = 062 (gt is 062; YEP!!!)
GPT claims that 031 + 073 = 104 (gt is 104; YEP!!!)
GPT claims that 033 + 014 = 047 (gt is 047; YEP!!!)
GPT claims that 033 + 034 = 067 (gt is 067; YEP!!!)
GPT claims that 059 + 073 = 132 (gt is 132; YEP!!!)
GPT claims that 018 + 014 = 032 (gt is 032; YEP!!!)
GPT claims that 062 + 097 = 159 (gt is 159; YEP!!!)
GPT claims that 067 + 078 = 145 (gt is 145; YEP!!!)
GPT claims t

GPT claims that 045 + 079 = 124 (gt is 124; YEP!!!)
GPT claims that 089 + 014 = 103 (gt is 103; YEP!!!)
GPT claims that 014 + 016 = 030 (gt is 030; YEP!!!)
GPT claims that 076 + 081 = 157 (gt is 157; YEP!!!)
GPT claims that 032 + 080 = 112 (gt is 112; YEP!!!)
GPT claims that 075 + 091 = 166 (gt is 166; YEP!!!)
GPT claims that 020 + 020 = 040 (gt is 040; YEP!!!)
GPT claims that 053 + 096 = 149 (gt is 149; YEP!!!)
GPT claims that 055 + 091 = 146 (gt is 146; YEP!!!)
GPT claims that 069 + 003 = 072 (gt is 072; YEP!!!)
GPT claims that 005 + 072 = 077 (gt is 077; YEP!!!)
GPT claims that 019 + 095 = 114 (gt is 114; YEP!!!)
GPT claims that 062 + 005 = 067 (gt is 067; YEP!!!)
GPT claims that 089 + 053 = 142 (gt is 142; YEP!!!)
GPT claims that 083 + 034 = 117 (gt is 117; YEP!!!)
GPT claims that 094 + 068 = 162 (gt is 162; YEP!!!)
GPT claims that 036 + 010 = 046 (gt is 046; YEP!!!)
GPT claims that 048 + 088 = 136 (gt is 136; YEP!!!)
GPT claims that 083 + 082 = 165 (gt is 165; YEP!!!)
GPT claims t

GPT claims that 079 + 034 = 113 (gt is 113; YEP!!!)
GPT claims that 074 + 091 = 165 (gt is 165; YEP!!!)
GPT claims that 050 + 059 = 109 (gt is 109; YEP!!!)
GPT claims that 024 + 023 = 047 (gt is 047; YEP!!!)
GPT claims that 036 + 083 = 119 (gt is 119; YEP!!!)
GPT claims that 019 + 043 = 062 (gt is 062; YEP!!!)
GPT claims that 042 + 082 = 124 (gt is 124; YEP!!!)
GPT claims that 060 + 001 = 061 (gt is 061; YEP!!!)
GPT claims that 039 + 085 = 124 (gt is 124; YEP!!!)
GPT claims that 033 + 084 = 117 (gt is 117; YEP!!!)
GPT claims that 091 + 054 = 145 (gt is 145; YEP!!!)
GPT claims that 017 + 084 = 101 (gt is 101; YEP!!!)
GPT claims that 023 + 086 = 109 (gt is 109; YEP!!!)
GPT claims that 060 + 056 = 116 (gt is 116; YEP!!!)
GPT claims that 088 + 054 = 142 (gt is 142; YEP!!!)
GPT claims that 087 + 079 = 166 (gt is 166; YEP!!!)
GPT claims that 072 + 007 = 079 (gt is 079; YEP!!!)
GPT claims that 077 + 089 = 166 (gt is 166; YEP!!!)
GPT claims that 026 + 083 = 109 (gt is 109; YEP!!!)
GPT claims t

GPT claims that 087 + 059 = 146 (gt is 146; YEP!!!)
GPT claims that 005 + 017 = 022 (gt is 022; YEP!!!)
GPT claims that 032 + 070 = 102 (gt is 102; YEP!!!)
GPT claims that 045 + 057 = 102 (gt is 102; YEP!!!)
GPT claims that 076 + 003 = 079 (gt is 079; YEP!!!)
GPT claims that 021 + 045 = 066 (gt is 066; YEP!!!)
GPT claims that 041 + 086 = 127 (gt is 127; YEP!!!)
GPT claims that 055 + 042 = 097 (gt is 097; YEP!!!)
GPT claims that 093 + 001 = 094 (gt is 094; YEP!!!)
GPT claims that 034 + 035 = 069 (gt is 069; YEP!!!)
GPT claims that 095 + 027 = 122 (gt is 122; YEP!!!)
GPT claims that 050 + 096 = 146 (gt is 146; YEP!!!)
GPT claims that 057 + 009 = 066 (gt is 066; YEP!!!)
GPT claims that 036 + 013 = 049 (gt is 049; YEP!!!)
GPT claims that 089 + 017 = 106 (gt is 106; YEP!!!)
GPT claims that 086 + 043 = 129 (gt is 129; YEP!!!)
GPT claims that 052 + 052 = 104 (gt is 104; YEP!!!)
final score: 276/1000 = 27.60% correct
