# Making Your Model Learn Addition!

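Given an input string such as "54+7", the model should predict the string "61". Both the problem and the answer are treated as sequences of characters: a recurrent encoder reads the problem, and a recurrent decoder emits the answer as one character prediction per time step.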

## 1. Import Libraries

In [1]:
import numpy as np

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import TimeDistributed, Dense, Dropout, SimpleRNN, RepeatVector
from tensorflow.keras.callbacks import EarlyStopping, LambdaCallback

from termcolor import colored

## 2. Generate Data

In [2]:
# Vocabulary of every character that can appear in a problem or an answer:
# the ten digits plus the '+' operator. The order matters, because each
# character's position in this string becomes its integer index.
all_chars = '0123456789+'

In [3]:
# Size of the vocabulary. The model uses this as the per-character feature width.
num_features = len(all_chars)
print("number of features:", num_features)

# Two-way lookup between a vocabulary character and its position in all_chars.
char_to_index = {ch: pos for pos, ch in enumerate(all_chars)}
index_to_char = dict(enumerate(all_chars))

number of features: 11


In [4]:
def generate_data(max_value=100, rng=None):
    """Draw one random addition problem as an (input string, target string) pair.

    Each operand is drawn uniformly from the half-open range [0, max_value).
    With the default max_value=100, the example is at most "99+99" (5 chars)
    and the label is at most "198" (3 chars).

    Args:
        max_value: Exclusive upper bound for each operand. Defaults to 100,
            the range this function originally hard-coded.
        rng: Optional ``np.random.Generator`` for reproducible sampling, e.g.
            ``np.random.default_rng(SEED)``. When None, the legacy global
            ``np.random`` state is used, exactly as before.

    Returns:
        A tuple ``(example, label)`` such as ``('54+7', '61')``.
    """
    if rng is None:
        # Legacy global RNG: keeps the original behaviour for existing callers.
        first = np.random.randint(0, max_value)
        second = np.random.randint(0, max_value)
    else:
        # Generator.integers is high-exclusive, like randint, so the range is unchanged.
        first = int(rng.integers(0, max_value))
        second = int(rng.integers(0, max_value))
    example = f"{first}+{second}"
    label = str(first + second)
    return example, label

generate_data()

('42+95', '137')

## 3. Create the Model

In [6]:
# Encoder-decoder layout:
#   - The first RNN reads the whole input sequence and keeps only its final
#     hidden vector.
#   - RepeatVector copies that vector once per output step.
#   - The second RNN unrolls those copies into a sequence.
#   - A per-step softmax then scores every vocabulary character.
hidden_units = 128   # width of both recurrent layers
max_time_steps = 5   # number of character predictions the decoder emits per example

model = Sequential()
# Encoder: the None time dimension allows variable-length input sequences.
model.add(SimpleRNN(hidden_units, input_shape=(None, num_features)))
model.add(RepeatVector(max_time_steps))
# Decoder: return_sequences=True keeps one hidden state per output step.
model.add(SimpleRNN(hidden_units, return_sequences=True))
# One probability distribution over the vocabulary at every output step.
model.add(TimeDistributed(Dense(num_features, activation='softmax')))

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
simple_rnn (SimpleRNN)       (None, 128)               17920     
_________________________________________________________________
repeat_vector (RepeatVector) (None, 5, 128)            0         
_________________________________________________________________
simple_rnn_1 (SimpleRNN)     (None, 5, 128)            32896     
_________________________________________________________________
time_distributed (TimeDistri (None, 5, 11)             1419      
Total params: 52,235
Trainable params: 52,235
Non-trainable params: 0
_________________________________________________________________
