# Data Science and Artificial Intelligence Practice Homework
DSAI HW3-Subtractor

## Prerequisites
- Python 3.6.4

## Install Dependencies
```sh
$ pip install -r requirements.txt
```

## Usage
```sh
$ python main.py [-o OPTION] [-d DATA] [-m MODEL]
```

| General Options              | Description                                                       |
| ---                          | ---                                                               |
| -h, --help                   | Show this help message and exit                                   |
| Advanced Options             | Description                                                       |
| -o gen                       | Generate data                                                     |
| -o train                     | Train the model                                                   |
| -o report\_training\_data    | Show all training data                                            |
| -o report\_validation\_data  | Show all validation data                                          |
| -o report\_testing\_data     | Show all testing data                                             |
| -o report\_accuracy          | Show accuracy                                                     |
| -o test                      | Enter a formula interactively                                     |
| -d DATA                      | Path of the training (or generated) data (default: src/data.pkl) |
| -m MODEL                     | Path of the model (default: src/my\_model.h5)                     |
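`main.py` itself is not shown here. As a rough sketch of how the options in the table could be wired up with the standard `argparse` module (the `build_parser` name and the `option`/`data`/`model` destinations are assumptions, not taken from the repository):

```python
import argparse

# Hypothetical reconstruction of the CLI described in the table;
# the real main.py may define its options differently.
OPTIONS = [
    "gen", "train",
    "report_training_data", "report_validation_data",
    "report_testing_data", "report_accuracy", "test",
]


def build_parser():
    parser = argparse.ArgumentParser(prog="main.py")
    parser.add_argument("-o", dest="option", choices=OPTIONS,
                        help="action to perform")
    parser.add_argument("-d", dest="data", default="src/data.pkl",
                        help="path of the training (or generated) data")
    parser.add_argument("-m", dest="model", default="src/my_model.h5",
                        help="path of the model")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["-o", "train"])
    print(args.option, args.data, args.model)  # train src/data.pkl src/my_model.h5
```

`argparse` adds `-h, --help` automatically, and `choices` rejects any `-o` value that is not listed in the table.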

## Architecture
### Data
- Training Data: 18,000
- Validation Data: 2,000
- Testing Data: 60,000

### Model
![model](img/seq2seq.png)

- Sequence-to-sequence (seq2seq) model
- Encoder: bidirectional LSTM (hidden size 256 per direction)
- Decoder: LSTM (hidden size 512, initialised with the concatenated forward and backward encoder states)

| Layer (type)                     | Output Shape        | Param #    | Connected to |
| ---                              | ---                 | ---        | ---          |
| input\_1 (InputLayer)            | (None, 7, 12)       | 0          |              |
| bidirectional\_1 (Bidirectional) | \[(None, 512), ...  | 550912     | input\_1[0][0] |
| reshape\_1 (Reshape)             | (None, 1, 512)      | 0          | bidirectional\_1[0][0] |
| concatenate\_1 (Concatenate)     | (None, 512)         | 0          | bidirectional\_1[0][1], bidirectional\_1[0][3] |
| concatenate\_2 (Concatenate)     | (None, 512)         | 0          | bidirectional\_1[0][2], bidirectional\_1[0][4] |
| lstm\_2 (LSTM)                   | \[(None, 512), ...  | 2099200    | reshape\_1[0][0], concatenate\_1[0][0], concatenate\_2[0][0], reshape\_1[0][0], lstm\_2[0][1], lstm\_2[0][2], reshape\_1[0][0], lstm\_2[1][1], lstm\_2[1][2], reshape\_1[0][0], lstm\_2[2][1], lstm\_2[2][2] |
| dense\_1 (Dense)                 | (None, 12)          | 6156       | lstm\_2[0][0] |
| dense\_2 (Dense)                 | (None, 12)          | 6156       | lstm\_2[1][0] |
| dense\_3 (Dense)                 | (None, 12)          | 6156       | lstm\_2[2][0] |
| dense\_4 (Dense)                 | (None, 12)          | 6156       | lstm\_2[3][0] |
| concatenate\_3 (Concatenate)     | (None, 48)          | 0          | dense\_1[0][0], dense\_2[0][0], dense\_3[0][0], dense\_4[0][0] |
| reshape\_2 (Reshape)             | (None, 4, 12)       | 0          | concatenate\_3[0][0] |

- Total params: 2,674,736
- Trainable params: 2,674,736
- Non-trainable params: 0
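The parameter counts in the summary can be reproduced by hand. A minimal sketch, assuming Keras' standard LSTM (four gates, each with an input kernel, a recurrent kernel and a bias) and a Dense layer with bias; `lstm_params` and `dense_params` are illustrative helpers, not part of the project:

```python
def lstm_params(input_dim, units):
    # Four gates, each with an input kernel, a recurrent kernel and a bias.
    return 4 * (units * (input_dim + units) + units)


def dense_params(input_dim, units):
    # Weight matrix plus bias vector.
    return input_dim * units + units


N_CHARS = 12  # len('0123456789- ')

encoder = 2 * lstm_params(N_CHARS, 256)   # bidirectional: forward + backward LSTM
decoder = lstm_params(512, 512)           # input: reshaped 512-d encoder output
heads = 4 * dense_params(512, N_CHARS)    # one Dense layer per answer position

print(encoder, decoder, heads, encoder + decoder + heads)
# 550912 2099200 24624 2674736
```

Almost 80% of the weights sit in the decoder LSTM, because its 512-unit recurrent kernel is much larger than the encoder's two 256-unit ones.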

### Result
| Epoch     | Training - Loss | Training - Accuracy | Validation - Loss | Validation - Accuracy |
| ---       | ---             | ---                 | ---               | ---                   |
| 1         | 1.5882          | 0.4285              | 1.4052            | 0.4711                |
| 7         | 0.2379          | 0.9257              | 0.2542            | 0.9132                |
| 9         | 0.1509          | 0.9541              | 0.1500            | 0.9561                |
| 15        | 0.0394          | 0.9921              | 0.0726            | 0.9773                |
| 31        | 0.0056          | 0.9998              | 0.0305            | 0.9903                |
| 100       | 2.9910e-04      | 1.0000              | 0.0251            | 0.9920                |

## Related Link
- [nbviewer](https://nbviewer.jupyter.org/github/yutongshen/DSAI-HW2-BooleanSearch/blob/master/main.ipynb)

## Authors
[Yu-Tong Shen](https://github.com/yutongshen/)



```python
import numpy as np
import random
```

# Parameter Configuration

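```python
TRAINING_SIZE = 80000          # total number of generated questions
DIGITS = 3                     # maximum number of digits per operand
ANS_DIGITS = DIGITS + 1        # length of the padded answer
MAXLEN = DIGITS + 1 + DIGITS   # longest query, e.g. '999-999'
chars = '0123456789- '         # vocabulary: digits, minus sign, padding space
```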

# Data Generation

```python
questions = []
expected = []
seen = set()


def f():
    # Draw a digit count in 1..DIGITS, then a number with at most that many digits.
    return random.choice(range(10 ** random.choice(range(1, DIGITS + 1))))


print('Generating data...')
while len(questions) < TRAINING_SIZE:
    b, a = sorted((f(), f()))   # ensure a >= b so the difference is non-negative
    key = (a, b)
    if key in seen:             # skip duplicate questions
        continue
    seen.add(key)
    query = '{}-{}'.format(a, b).ljust(MAXLEN)   # pad with spaces to MAXLEN
    ans = str(a - b).ljust(ANS_DIGITS)
    questions.append(query)
    expected.append(ans)
print('Total subtraction questions:', len(questions))
```

```
Generating data...
Total subtraction questions: 80000
```
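The operand sampler `f` is not uniform over 0 to 999: it first picks a digit count `k`, then a number below `10 ** k`. A short exact calculation with `fractions.Fraction` (the helper `prob_below` is illustrative, not part of the notebook) gives the probability that a single draw falls below a given limit:

```python
from fractions import Fraction

DIGITS = 3


def prob_below(limit, digits=DIGITS):
    """Exact probability that one call of f() returns a value < limit."""
    total = Fraction(0)
    for k in range(1, digits + 1):
        n = 10 ** k                      # f() then draws uniformly from range(n)
        total += Fraction(1, digits) * Fraction(min(limit, n), n)
    return total


print(prob_below(10))    # 37/100
print(prob_below(100))   # 7/10
```

So about 37% of raw draws are single-digit and 70% are below 100, compared with 1% and 10% for a uniform draw over 0 to 999. Duplicate pairs are rejected via `seen`, and only 5,050 pairs with 0 ≤ b ≤ a ≤ 99 exist, so at least 74,950 of the 80,000 questions still contain an operand of 100 or more.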


```python
for i in range(10):
    print(questions[i], '=', expected[i])
```

```
808-376 = 432 
6-0     = 6   
6-1     = 5   
328-6   = 322 
689-17  = 672 
96-80   = 16  
585-425 = 160 
43-3    = 40  
374-25  = 349 
37-7    = 30  
```


# Processing
- Size of training data:   18,000
- Size of validation data: 2,000
- Remaining samples (testing data): 60,000

```python
class CharacterTable:
    """One-hot encode strings over a fixed character set and decode them back."""

    def __init__(self, chars):
        self.chars  = list(chars)
        self.len    = len(chars)
        self.encode = {}
        for i, key in enumerate(self.chars):
            self.encode[key] = np.zeros(self.len, np.float32)
            self.encode[key][i] = 1.

    def encoder(self, C):
        # One row per character; unknown characters are left as all-zero rows.
        result = np.zeros((len(C), self.len))
        for i, c in enumerate(C):
            try:
                result[i] = self.encode[c]
            except KeyError:
                pass
        return result

    def decoder(self, x):
        # Pick the most probable character at each position.
        x = x.argmax(axis=-1)
        return ''.join(self.chars[i] for i in x)
```
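A pure-Python re-sketch of the same encode/decode round trip (no NumPy; `CHARS`, `encode` and `decode` are illustrative names, not part of the notebook) makes one subtlety visible: a character outside the vocabulary, such as `+`, is encoded as an all-zero row, and the argmax of that row is index 0, so it decodes back as `'0'` instead of raising an error.

```python
CHARS = '0123456789- '


def encode(text):
    # One-hot row per character; unknown characters stay all-zero,
    # like the KeyError fallback in CharacterTable.encoder.
    rows = []
    for c in text:
        row = [0.0] * len(CHARS)
        if c in CHARS:
            row[CHARS.index(c)] = 1.0
        rows.append(row)
    return rows


def decode(rows):
    # Argmax per row; on ties max() keeps the first index, as np.argmax does,
    # so an all-zero row decodes to CHARS[0] == '0'.
    return ''.join(CHARS[max(range(len(r)), key=r.__getitem__)] for r in rows)


print(repr(decode(encode('12-3 '))))   # '12-3 '
print(repr(decode(encode('12+3'))))    # '1203'
```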

```python
ct = CharacterTable(chars)

# One-hot encode all questions and answers.
x = np.zeros((len(questions), MAXLEN, len(chars)), np.float32)
y = np.zeros((len(expected), ANS_DIGITS, len(chars)), np.float32)
for i, sentence in enumerate(questions):
    x[i] = ct.encoder(sentence)
for i, sentence in enumerate(expected):
    y[i] = ct.encoder(sentence)

# First 18,000 samples for training, the next 2,000 for validation;
# the remaining 60,000 are left for testing.
train_x = x[:18000]
train_y = y[:18000]

validation_x = x[18000:20000]
validation_y = y[18000:20000]
```

# Model

![model](img/seq2seq.png)

- Sequence-to-sequence model
- Encoder: bidirectional LSTM
- Decoder: LSTM

```python
from keras.models import Model
from keras.layers import Input, LSTM, Dense, Reshape, Bidirectional, Concatenate

HIDDEN_SIZE = 256

# Encoder: a bidirectional LSTM that also returns its final states.
encoder_inputs = Input(shape=(MAXLEN, len(chars)))
encoder = Bidirectional(LSTM(HIDDEN_SIZE, return_state=True))
encoder_outputs, forward_h, forward_c, backward_h, backward_c = encoder(encoder_inputs)

# Concatenate forward and backward states (2 * HIDDEN_SIZE) to initialise the decoder.
state_h = Concatenate()([forward_h, backward_h])
state_c = Concatenate()([forward_c, backward_c])
states = [state_h, state_c]

# Set up the decoder, which processes one timestep at a time.
decoder_inputs = Reshape((1, HIDDEN_SIZE * 2))
decoder_lstm = LSTM(HIDDEN_SIZE * 2, return_state=True)

all_outputs = []
inputs = decoder_inputs(encoder_outputs)

for _ in range(ANS_DIGITS):
    # Run the shared decoder LSTM on one timestep. The input is the same
    # encoder output at every step; only the states are carried forward.
    outputs, state_h, state_c = decoder_lstm(inputs, initial_state=states)
    states = [state_h, state_c]

    # Predict one answer character (a separate Dense layer per position).
    outputs = Dense(len(chars), activation='softmax')(outputs)
    all_outputs.append(outputs)

# Concatenate all predictions and reshape to (ANS_DIGITS, len(chars)).
decoder_outputs = Concatenate()(all_outputs)
decoder_outputs = Reshape((ANS_DIGITS, len(chars)))(decoder_outputs)

# Define and compile the model.
model = Model(encoder_inputs, decoder_outputs)
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

model.summary()

# Roughly len(train_x) / 128, rounded down to a multiple of 100 (at least 100);
# this gives 100 for 18,000 training samples.
batch_size = int(len(train_x) / 128 / 100) * 100
if batch_size == 0:
    batch_size = 100

model.fit(train_x, train_y,
          batch_size=batch_size, epochs=100,
          verbose=1, validation_data=(validation_x, validation_y))
```

```
Using TensorFlow backend.

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 7, 12)        0                                            
__________________________________________________________________________________________________
bidirectional_1 (Bidirectional) [(None, 512), (None, 550912      input_1[0][0]                    
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 1, 512)       0           bidirectional_1[0][0]            
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 512)          0           bidirectional_1[0][1]            
                                                                 bidirectional_1[0][3]            
__________

Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100
```



# Error check

```python
print('Error check')
# Evaluate on all 80,000 samples (training, validation and testing).
pred = model.predict(x)
err = []
for i in range(len(x)):
    # A sample counts as an error if any character of the decoded answer differs.
    if ct.decoder(pred[i]) != ct.decoder(y[i]):
        err.append(i)
print(len(err), '/', len(x))
print()
print('Prediction'.ljust(MAXLEN + ANS_DIGITS + 3), 'Ans')
print('-' * (MAXLEN + ANS_DIGITS * 2 + 4))
for i in err:
    print(ct.decoder(x[i]), '=', ct.decoder(pred[i]), ct.decoder(y[i]))
```


```
Error check
2665 / 80000

Prediction     Ans
-------------------
705-401 = 204  304 
722-717 = 3    5   
320-305 = 16   15  
296-258 = 48   38  
894-853 = 31   41  
906-828 = 68   78  
99-91   = 7    8   
524-509 = 26   15  
478-79  = 499  399 
878-803 = 76   75  
102-33  = 79   69  
773-715 = 59   58  
324-265 = 69   59  
799-105 = 684  694 
515-468 = 57   47  
315-259 = 16   56  
571-495 = 77   76  
861-751 = 100  110 
493-204 = 288  289 
992-989 = 5    3   
107-80  = 26   27  
985-927 = 59   58  
870-70  = 700  800 
691-596 = 15   95  
172-163 = 19   9   
192-152 = 50   40  
874-372 = 402  502 
798-106 = 682  692 
819-769 = 60   50  
196-148 = 58   48  
285-244 = 31   41  
927-882 = 55   45  
564-265 = 399  299 
872-809 = 62   63  
794-703 = 89   91  
660-652 = 1    8   
147-70  = 76   77  
837-755 = 72   82  
606-546 = 50   60  
863-799 = 55   64  
951-204 = 847  747 
483-481 = 1    2   
190-105 = 88   85  
404-325 = 89   79  
929-882 = 58   47  
492-299 = 293  193 
782-703 = 89   
753-746 = 8    7   
510-433 = 87   77  
685-586 = 19   99  
544-455 = 99   89  
818-759 = 69   59  
222-207 = 1    15  
138-117 = 11   21  
776-729 = 57   47  
500-189 = 312  311 
660-630 = 40   30  
920-886 = 44   34  
679-679 = 1    0   
710-699 = 1    11  
710-459 = 261  251 
310-306 = 5    4   
631-229 = 302  402 
606-102 = 404  504 
699-503 = 195  196 
823-737 = 87   86  
950-855 = 15   95  
761-722 = 49   39  
671-628 = 33   43  
980-834 = 136  146 
503-464 = 49   39  
927-914 = 14   13  
201-164 = 17   37  
131-69  = 52   62  
213-193 = 10   20  
632-590 = 43   42  
187-140 = 57   47  
312-294 = 28   18  
683-605 = 88   78  
506-460 = 45   46  
663-365 = 398  298 
628-602 = 25   26  
445-399 = 55   46  
496-299 = 297  197 
466-457 = 19   9   
292-216 = 67   76  
563-507 = 55   56  
705-699 = 5    6   
900-887 = 8    13  
660-660 = 10   0   
860-779 = 82   81  
922-913 = 1    9   
169-145 = 23   24  
148-90  = 59   58  
299-10  = 299  289 
211-203 = 18   8   
340-278 = 72   62  
647-606 = 51   41  
107-48  = 69   59  
840-759 = 12   81  
319-291 = 29   28  
751-702 = 59   49  
797-788 = 1    9   
132-117 = 1    15  
884-804 = 89   80  
608-103 = 405  505 
932-794 = 128  138 
299-221 = 77   78  
824-120 = 604  704 
818-790 = 29   28  
391-313 = 88   78  
867-865 = 1    2   
373-294 = 89   79  
138-133 = 8    5   
987-910 = 76   77  
538-380 = 148  158 
234-210 = 22   24  
772-471 = 201  301 
753-500 = 243  253 
990-904 = 88   86  
253-209 = 43   44  
480-186 = 394  294 
275-256 = 29   19  
362-314 = 49   48  
900-770 = 120  130 
145-96  = 59   49  
818-731 = 97   87  
727-681 = 56   46  
479-170 = 209  309 
741-499 = 252  242 
102-98  = 1    4   
513-504 = 1    9   
435-387 = 58   48  
532-525 = 6    7   
984-976 = 1    8   
252-198 = 55   54  
798-608 = 180  190 
575-521 = 55   54  
917-598 = 329  319 
590-306 = 285  284 
609-574 = 34   35  
546-500 = 45   46  
887-711 = 177  176 
990-490 = 400  500 
119-82  = 36   37  
446-441 = 1    5   
222-221 = 3    1   
528-461 = 66   67  
724-705 = 29   19  
975-971 = 1    4   
957-925 = 33   32  
369-364 = 1    5   
380-81  = 399  299 
747-743 = 2    4   
586-579 = 1    7   
489-417 = 81   72  
191-181 = 80   10  
923-899 = 44   24  
580-204 = 366  376 
248-188 = 50   60  
304-256 = 188  48  
706-700 = 1    6   
330-293 = 27   37  
911-780 = 121  131 
846-246 = 500  600 
937-904 = 34   33  
988-891 = 98   97  
603-554 = 59   49  
339-253 = 87   86  
825-823 = 9    2   
993-991 = 1    2   
373-370 = 93   3   
327-190 = 237  137 
119-72  = 46   47  
947-915 = 33   32  
734-689 = 55   45  
961-949 = 1    12  
891-884 = 4    7   
981-923 = 59   58  
844-814 = 20   30  
979-830 = 159  149 
293-207 = 87   86  
831-332 = 599  499 
958-874 = 94   84  
715-115 = 500  600 
159-62  = 96   97  
131-79  = 53   52  
602-556 = 36   46  
129-129 = 90   0   
810-204 = 506  606 
626-600 = 25   26  
947-348 = 699  599 
658-651 = 5    7   
841-348 = 593  493 
101-84  = 18   17  
737-703 = 24   34  
317-285 = 42   32  
183-154 = 39   29  
931-729 = 203  202 
694-688 = 1    6   
951-808 = 133  143 
130-78  = 53   52  
149-124 = 35   25  
291-205 = 87   86  
511-488 = 32   23  
871-801 = 61   70  
900-749 = 161  151 
960-882 = 88   78  
800-599 = 101  201 
674-631 = 44   43  
890-419 = 472  471 
252-214 = 48   38  
613-613 = 90   0   
919-821 = 97   98  
997-958 = 49   39  
786-291 = 595  495 
960-911 = 59   49  
593-579 = 15   14  
309-274 = 25   35  
784-738 = 56   46  
741-646 = 15   95  
110-88  = 32   22  
274-245 = 39   29  
764-706 = 59   58  
810-279 = 532  531 
397-355 = 44   42  
940-842 = 188  98  
937-853 = 94   84  
103-64  = 49   39  
893-593 = 200  300 
393-311 = 83   82  
870-309 = 662  561 
510-506 = 3    4   
703-659 = 34   44  
202-190 = 11   12  
269-169 = 900  100 
669-168 = 401  501 
815-659 = 166  156 
447-430 = 16   17  
891-833 = 59   58  
519-428 = 101  91  
847-147 = 600  700 
747-739 = 1    8   
889-810 = 87   79  
201-166 = 15   35  
398-301 = 89   97  
779-717 = 61   62  
```
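The error check counts a sample as wrong if any character of its answer differs. The `accuracy` reported during training (see the Result table) should, in Keras 2.x with `categorical_crossentropy`, resolve to categorical accuracy averaged over every output position, so the two numbers are not directly comparable. By the exact-match measure, 2665 errors out of 80,000 samples corresponds to an accuracy of about 96.7%. A small stdlib sketch with illustrative helpers `char_accuracy` and `sequence_accuracy`, using four rows from the listing above:

```python
def char_accuracy(preds, answers):
    # Fraction of individual answer positions that match
    # (per-position accuracy, as assumed for Keras' 'accuracy' here).
    pairs = [(p, a) for pred, ans in zip(preds, answers) for p, a in zip(pred, ans)]
    return sum(p == a for p, a in pairs) / len(pairs)


def sequence_accuracy(preds, answers):
    # Fraction of answers that match exactly, as in the error check.
    return sum(p == a for p, a in zip(preds, answers)) / len(answers)


preds = ['432 ', '6   ', '204 ', '3   ']
answers = ['432 ', '6   ', '304 ', '5   ']
print(char_accuracy(preds, answers))      # 0.875
print(sequence_accuracy(preds, answers))  # 0.5
```

A single wrong digit costs only one quarter of a sample under the per-position measure, but the whole sample under exact match.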


# Test

```python
q = '123-23'

# Pad (or truncate) the query to MAXLEN, then add a batch dimension.
q_padding = q.ljust(MAXLEN)[:MAXLEN]
test_x = ct.encoder(q_padding)
pred_y = model.predict(test_x.reshape(1, MAXLEN, len(chars)))
print(q, '=', ct.decoder(pred_y[0]))
```

```
123-23 = 100 
```
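The model only saw questions of the form `a-b` with `a >= b`, at most three digits per operand and no leading zeros (the generator formats plain integers). The test cell pads or silently truncates any other input, such as `23-123` or `12+3`, and feeds it to the model anyway. A hypothetical `check_query` helper, sketched below with the standard `re` module, could reject such inputs before calling `model.predict`:

```python
import re

DIGITS = 3
MAXLEN = DIGITS + 1 + DIGITS

# Operands as produced by str(int): no leading zeros, at most DIGITS digits.
_OPERAND = r'(0|[1-9][0-9]{0,%d})' % (DIGITS - 1)
QUERY_RE = re.compile(_OPERAND + '-' + _OPERAND)


def check_query(q):
    """Hypothetical helper: validate a query and pad it to MAXLEN."""
    m = QUERY_RE.fullmatch(q)
    if m is None:
        raise ValueError('expected "a-b" with at most %d digits per operand' % DIGITS)
    a, b = int(m.group(1)), int(m.group(2))
    if a < b:
        raise ValueError('the training data only covers a >= b')
    return q.ljust(MAXLEN)


print(repr(check_query('123-23')))   # '123-23 '
```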
