vincent-163/transformer-arithmetic

Performing Arithmetic with Transformer

This project demonstrates that a generic neural model can be trained to perform complex arithmetic operations without an architecture designed explicitly for the task. Our model computes 5-digit by 5-digit decimal multiplication at 100% accuracy. Specifically, we train the GPT-2 model on a large number of generated expressions that spell out the process of computing a multiplication step by step. For example, given 78268*53567 it should output 4192581956.

See the "How does it work" section for an explanation of the inner mechanics.

How to train

Install dependencies with pip install -r requirements.txt.

Edit the following lines in train.py:

  • Edit gpus=4 to the number of GPUs on your machine.
  • Edit batch_size=16 to fit the memory size of your GPU.
  • Adjust accumulate_grad_batches=8 accordingly to control the effective batch size (gpus*batch_size*accumulate_grad_batches). In my early tests I found that larger batch sizes train faster and more stably.
  • Optionally set the resume_from_checkpoint argument to the path of a checkpoint file to resume training from that checkpoint. PyTorch Lightning should automatically save checkpoints during training.
  • You may also want to adjust the hyperparameters in main.py. Since the task is relatively simple, a smaller model should still work, and if it does it will train faster.
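The effective batch size formula above can be checked with a few lines of arithmetic. This is only a sketch using the default values quoted in the list; the single-GPU scenario is a hypothetical illustration, not a setting taken from train.py.

```python
# Default values quoted in the list above; adjust to your hardware.
gpus = 4
batch_size = 16
accumulate_grad_batches = 8

# Effective batch size = gpus * batch_size * accumulate_grad_batches.
effective_batch_size = gpus * batch_size * accumulate_grad_batches
print(effective_batch_size)  # 512

# Hypothetical scenario: keep the same effective batch size on one GPU
# by raising the number of accumulation steps instead.
single_gpu_accumulation = effective_batch_size // (1 * batch_size)
print(single_gpu_accumulation)  # 32
```

If you reduce gpus or batch_size, raising accumulate_grad_batches by the same factor keeps the product unchanged.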

Then just run python train.py and watch it train. It typically takes around 100 epochs (10000000 samples) to converge, at which point it should reach 100% accuracy.

How to test

If you don't want to train the model and just want to test it out, you can download the pretrained model at https://drive.google.com/file/d/1YKzHTec5FsN6NftR3uc5j-SpmNpxTDsc/view?usp=sharing. Put the model (named epoch=98.ckpt) in the same folder as all other files. Otherwise, modify test.py and change the filename of the checkpoint to the one generated by train.py.

Run python test.py and give it prompts in the format xxxxx*xxxxx;, replacing each x with a decimal digit (the first digit of each number must not be zero). For example, enter 12345*54321; and watch it compute the answer.

The answer is expected to be found between the last = sign and the $ sign.
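If you want to pull the answer out of the generated text programmatically, a small helper along these lines should do. It is a sketch, not part of test.py: the function name extract_answer is made up for illustration, and it assumes the output contains a $ sign.

```python
def extract_answer(output: str) -> str:
    """Return the text between the last '=' before the '$' sign and the '$' itself."""
    end = output.index("$")              # raises ValueError if there is no '$'
    start = output.rindex("=", 0, end)   # last '=' that precedes the '$'
    return output[start + 1:end]


# Tail of the 39*96 sequence shown later in this README.
print(extract_answer("23+351;310425070303000=374;=3744$"))  # 3744
```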

How does it work

Think about how humans perform decimal multiplications:

   39
x  96
-----
  234
 351
-----
 3744

Multiplications can be computed vertically. We first decompose the multiplier into digits, and multiply each digit by the multiplicand, then we shift these results by the position of the digit and sum them together. There are two things left to do:

  • Single digit multiplication. In the above example we need to compute 39*6=234 and 39*9=351. This is again done by decomposing the multiplicand into digits, multiplying each digit by the single-digit multiplier, and summing the shifted results. The last unit of computation, single-digit by single-digit multiplication, can be done by looking up the multiplication table.

  • While the process of summing up the inputs may seem trivial, the GPT-3 model is known to struggle with multi-operation arithmetic (see "Single Digit Three Ops" performance in the original GPT-3 paper). We therefore break the sum up and add two numbers at a time. Each addition proceeds from right to left, one digit position per step. To make the task even easier for the model, each step first writes out the two digits being added, then the carry coming in from the previous step, and then the resulting ones digit, so the model doesn't have to work out where to find the digits for the operation. When the two numbers have different numbers of digits, the missing digits are filled with zeros. However, the operation does not stop as soon as the digits of both numbers are used up, since a carry can extend past the most significant digit, as in 9999+101. Instead, the operation stops when the two digits and the carry are all zero. The sequence for 9999+101 looks like this (spaces added for clarity; they don't exist in actual training samples):

    9999+101;9100 9010 9111 9010 0011 000=10100
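The addition format can be reproduced with a short function. This is a minimal sketch inferred from the samples in this README, not the code in expr.py: the name addition_trace is hypothetical, and it interprets the stopping rule as "both operands are used up and the carry is zero".

```python
def addition_trace(a: int, b: int) -> str:
    """Write out a + b digit by digit, right to left, in the sample format.

    Each step emits four characters: digit of a, digit of b,
    incoming carry, resulting ones digit. A final "000" marks the end.
    """
    steps = []
    carry = 0
    x, y = a, b
    # Assumption: stop once both operands are exhausted and no carry remains.
    while x or y or carry:
        dx, dy = x % 10, y % 10          # missing digits become 0
        s = dx + dy + carry
        steps.append(f"{dx}{dy}{carry}{s % 10}")
        carry = s // 10
        x //= 10
        y //= 10
    steps.append("000")                  # two zero digits and a zero carry
    return f"{a}+{b};{''.join(steps)}={a + b}"


print(addition_trace(9999, 101))
# 9999+101;91009010911190100011000=10100
```

The printed line matches the 9999+101 sample above once the spaces are removed.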
    

The full sequence for 39*96 is:

39*96;39*6;9*6;=54;3*6;=18;5+18;58030112000=23;=234;39*9;9*9;=81;3*9;=27;8+27;87050213000=35;=351;23+351;310425070303000=374;=3744$
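The whole sequence can be rebuilt from the rules described above. The following is a sketch reverse-engineered from the samples in this README; it is not the generate_multiplication function from expr.py, and the names addition_trace, accumulate, and multiplication_trace are hypothetical. It has only been checked against the samples shown here, so inputs containing zero digits may be formatted differently by the real generator.

```python
def addition_trace(a: int, b: int) -> str:
    """Digit-by-digit addition: digit of a, digit of b, carry in, ones digit."""
    steps, carry, x, y = [], 0, a, b
    while x or y or carry:
        s = x % 10 + y % 10 + carry
        steps.append(f"{x % 10}{y % 10}{carry}{s % 10}")
        carry, x, y = s // 10, x // 10, y // 10
    return f"{a}+{b};{''.join(steps)}000={a + b}"


def accumulate(terms: list, out: list) -> int:
    """Sum terms[i] * 10**i two numbers at a time, appending each addition to out."""
    running = terms[0]
    low_digits = ""
    for term in terms[1:]:
        # The last digit of the running sum is final; only the rest is carried on.
        low_digits = str(running % 10) + low_digits
        carried = running // 10
        out.append(addition_trace(carried, term) + ";")
        running = carried + term
    return int(str(running) + low_digits)


def multiplication_trace(a: int, b: int) -> str:
    out = [f"{a}*{b};"]
    partials = []
    for d in map(int, reversed(str(b))):        # multiplier digits, right to left
        out.append(f"{a}*{d};")
        products = []
        for x in map(int, reversed(str(a))):    # multiplicand digits, right to left
            products.append(x * d)              # multiplication-table lookup
            out.append(f"{x}*{d};={x * d};")
        partial = accumulate(products, out)
        out.append(f"={partial};")
        partials.append(partial)
    out.append(f"={accumulate(partials, out)}$")
    return "".join(out)


print(multiplication_trace(39, 96))
```

For 39 and 96 this prints the same sequence as the one shown above.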

Refer to expr.py for details regarding how the expressions are generated. Try out the generate_multiplication function for yourself (enter any two numbers as arguments).

The first 20 training samples can be found in exprs.txt. In actual training they are generated on the fly using pseudo random number generators. Here is the first line of exprs.txt:

78268*53567;78268*7;8*7;=56;6*7;=42;2*7;=14;8*7;=56;7*7;=49;5+42;52070404000=47;4+14;44080101000=18;1+56;16070505000=57;5+49;59040415000=54;=547876;78268*6;8*6;=48;6*6;=36;2*6;=12;8*6;=48;7*6;=42;4+36;46000314000=40;4+12;42060101000=16;1+48;18090404000=49;4+42;42060404000=46;=469608;78268*5;8*5;=40;6*5;=30;2*5;=10;8*5;=40;7*5;=35;4+30;40040303000=34;3+10;30030101000=13;1+40;10010404000=41;4+35;45090303000=39;=391340;78268*3;8*3;=24;6*3;=18;2*3;=6;8*3;=24;7*3;=21;2+18;28000112000=20;2+6;2608000=8;0+24;04040202000=24;2+21;21030202000=23;=234804;78268*5;8*5;=40;6*5;=30;2*5;=10;8*5;=40;7*5;=35;4+30;40040303000=34;3+10;30030101000=13;1+40;10010404000=41;4+35;45090303000=39;=391340;54787+469608;780580197603491456120415000=524395;52439+391340;900934074307210359040314000=443779;44377+234804;740170183801441943070202000=279181;27918+391340;800814059302711929010314000=419258;=4192581956$
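A sample like the one above can be sanity-checked mechanically. The sketch below is an illustration rather than part of the repository: check_trace is a hypothetical name, and the regular expressions assume the exact layout seen in the samples in this README.

```python
import re


def check_trace(trace: str) -> bool:
    """Verify the lookups, the additions, and the final answer of one sample."""
    # The header "a*b;" gives the operands.
    a, b = map(int, re.match(r"(\d+)\*(\d+);", trace).groups())
    # The answer sits between the last '=' and the '$'.
    answer = int(trace[trace.rindex("=") + 1:trace.index("$")])
    # Every single-digit lookup "x*d;=p;" must satisfy x * d == p.
    lookups_ok = all(
        int(x) * int(d) == int(p)
        for x, d, p in re.findall(r"(\d)\*(\d);=(\d+);", trace)
    )
    # Every addition "x+y;<steps>=z;" must satisfy x + y == z.
    additions_ok = all(
        int(x) + int(y) == int(z)
        for x, y, z in re.findall(r"(\d+)\+(\d+);\d+=(\d+);", trace)
    )
    return lookups_ok and additions_ok and a * b == answer


sample = ("39*96;39*6;9*6;=54;3*6;=18;5+18;58030112000=23;=234;"
          "39*9;9*9;=81;3*9;=27;8+27;87050213000=35;=351;"
          "23+351;310425070303000=374;=3744$")
print(check_trace(sample))  # True
```

The same check can be applied to each line of exprs.txt or to the model's own output.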
