In [2]:
import numpy as np
import pandas as pd
import torch

# create training data

We will encode integer number to 10 digits of binary number.

In [3]:
# Number of binary digits
NUM_DIGITS = 10

# Represent each input by an array of its binary digits.
def binary_encode(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)])

# One-hot encode the desired outputs: [number, "fizz", "buzz", "fizzbuzz"] => [0, 1, 2, 3]
def fizz_buzz_encode(i):
    if   i % 15 == 0: return 3
    elif i % 5  == 0: return 2
    elif i % 3  == 0: return 1
    else:             return 0

# decode the encoded output to the original input: [0, 1, 2, 3] => [number, "fizz", "buzz", "fizzbuzz"]
def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

In [4]:
X = pd.DataFrame([binary_encode(i, NUM_DIGITS) for i in range(1, 2 ** NUM_DIGITS)])
y = pd.DataFrame([fizz_buzz_encode(i) for i in range(1, 2 ** NUM_DIGITS)])

In [5]:
X.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,1,0,0,0,0,0,0,0,0,0
1,0,1,0,0,0,0,0,0,0,0
2,1,1,0,0,0,0,0,0,0,0
3,0,0,1,0,0,0,0,0,0,0
4,1,0,1,0,0,0,0,0,0,0


In [6]:
y.head()

Unnamed: 0,0
0,0
1,0
2,1
3,0
4,2


# define model

Define a single layer perceptron with 100 hidden unit. The input layer and the output layer will have 10 and 4 units to match our input and output size.

In [7]:
NUM_HIDDEN = 100

# Define the model
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4)
)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.05)

# train

Train the model using examples ranging from 101 to 1023. We will split the data into train and validation set. We will shuffle the train data and train with batch size of 128.

In [8]:
from sklearn.model_selection import train_test_split

In [22]:
sum([2**i * int(_) for i,_ in enumerate(X[100])])

101

In [23]:
def split_train_test(X, y, test_size):
    idx = np.arange(len(X))
    train_idx, val_idx = train_test_split(idx, test_size=0.1)
    trainX, trainY = X[train_idx], y[train_idx]
    valX, valY = X[val_idx], y[val_idx]
    return (trainX, trainY), (valX, valY)

In [24]:
BATCH_SIZE = 128

if not isinstance(X, torch.Tensor):
    X = torch.tensor(X.values, dtype=torch.float)
    y = torch.tensor(y.values[:, 0], dtype=torch.long)

# Start training it
for epoch in range(10000):
    # create validation set by randomly select 10% of train data
    # the train set is also shuffled as we split the data.
    (trainX, trainY), (valX, valY) = split_train_test(X[100:], y[100:], test_size=0.1)
    
    for start in range(0, len(trainX), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = trainX[start:end]
        batchY = trainY[start:end]

        y_pred = model(batchX)
        loss = loss_fn(y_pred, batchY)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Find loss on training data
    loss = loss_fn(model(trainX), trainY).item()
    
    # Find accuracy on validation data
    acc = (model(valX).argmax(1) == valY).sum().item() / len(valX)
    
    print(f'Epoch: {epoch}, Loss: {loss:.4f}, Acc: {acc:.2f}')

Epoch: 0, Loss: 0.0081, Acc: 1.00
Epoch: 1, Loss: 0.0080, Acc: 1.00
Epoch: 2, Loss: 0.0082, Acc: 1.00
Epoch: 3, Loss: 0.0080, Acc: 1.00
Epoch: 4, Loss: 0.0079, Acc: 1.00
Epoch: 5, Loss: 0.0077, Acc: 1.00
Epoch: 6, Loss: 0.0078, Acc: 1.00
Epoch: 7, Loss: 0.0078, Acc: 1.00
Epoch: 8, Loss: 0.0080, Acc: 1.00
Epoch: 9, Loss: 0.0079, Acc: 1.00
Epoch: 10, Loss: 0.0077, Acc: 1.00
Epoch: 11, Loss: 0.0078, Acc: 1.00
Epoch: 12, Loss: 0.0082, Acc: 1.00
Epoch: 13, Loss: 0.0078, Acc: 1.00
Epoch: 14, Loss: 0.0080, Acc: 1.00
Epoch: 15, Loss: 0.0079, Acc: 1.00
Epoch: 16, Loss: 0.0076, Acc: 1.00
Epoch: 17, Loss: 0.0081, Acc: 1.00
Epoch: 18, Loss: 0.0078, Acc: 1.00
Epoch: 19, Loss: 0.0080, Acc: 1.00
Epoch: 20, Loss: 0.0082, Acc: 1.00
Epoch: 21, Loss: 0.0078, Acc: 1.00
Epoch: 22, Loss: 0.0081, Acc: 1.00
Epoch: 23, Loss: 0.0078, Acc: 1.00
Epoch: 24, Loss: 0.0082, Acc: 1.00
Epoch: 25, Loss: 0.0079, Acc: 1.00
Epoch: 26, Loss: 0.0083, Acc: 1.00
Epoch: 27, Loss: 0.0081, Acc: 1.00
Epoch: 28, Loss: 0.0075, Acc: 

Epoch: 253, Loss: 0.0073, Acc: 1.00
Epoch: 254, Loss: 0.0076, Acc: 1.00
Epoch: 255, Loss: 0.0079, Acc: 1.00
Epoch: 256, Loss: 0.0075, Acc: 1.00
Epoch: 257, Loss: 0.0075, Acc: 1.00
Epoch: 258, Loss: 0.0076, Acc: 1.00
Epoch: 259, Loss: 0.0079, Acc: 1.00
Epoch: 260, Loss: 0.0075, Acc: 1.00
Epoch: 261, Loss: 0.0078, Acc: 1.00
Epoch: 262, Loss: 0.0069, Acc: 1.00
Epoch: 263, Loss: 0.0075, Acc: 1.00
Epoch: 264, Loss: 0.0077, Acc: 1.00
Epoch: 265, Loss: 0.0078, Acc: 1.00
Epoch: 266, Loss: 0.0075, Acc: 1.00
Epoch: 267, Loss: 0.0076, Acc: 1.00
Epoch: 268, Loss: 0.0078, Acc: 1.00
Epoch: 269, Loss: 0.0079, Acc: 1.00
Epoch: 270, Loss: 0.0073, Acc: 1.00
Epoch: 271, Loss: 0.0072, Acc: 1.00
Epoch: 272, Loss: 0.0073, Acc: 1.00
Epoch: 273, Loss: 0.0079, Acc: 1.00
Epoch: 274, Loss: 0.0075, Acc: 1.00
Epoch: 275, Loss: 0.0077, Acc: 1.00
Epoch: 276, Loss: 0.0075, Acc: 1.00
Epoch: 277, Loss: 0.0077, Acc: 1.00
Epoch: 278, Loss: 0.0079, Acc: 1.00
Epoch: 279, Loss: 0.0071, Acc: 1.00
Epoch: 280, Loss: 0.0075, Ac

Epoch: 500, Loss: 0.0072, Acc: 1.00
Epoch: 501, Loss: 0.0075, Acc: 1.00
Epoch: 502, Loss: 0.0075, Acc: 1.00
Epoch: 503, Loss: 0.0076, Acc: 1.00
Epoch: 504, Loss: 0.0072, Acc: 1.00
Epoch: 505, Loss: 0.0070, Acc: 1.00
Epoch: 506, Loss: 0.0069, Acc: 1.00
Epoch: 507, Loss: 0.0075, Acc: 1.00
Epoch: 508, Loss: 0.0073, Acc: 1.00
Epoch: 509, Loss: 0.0076, Acc: 1.00
Epoch: 510, Loss: 0.0073, Acc: 1.00
Epoch: 511, Loss: 0.0073, Acc: 1.00
Epoch: 512, Loss: 0.0073, Acc: 1.00
Epoch: 513, Loss: 0.0070, Acc: 1.00
Epoch: 514, Loss: 0.0073, Acc: 1.00
Epoch: 515, Loss: 0.0074, Acc: 1.00
Epoch: 516, Loss: 0.0074, Acc: 1.00
Epoch: 517, Loss: 0.0074, Acc: 1.00
Epoch: 518, Loss: 0.0075, Acc: 1.00
Epoch: 519, Loss: 0.0075, Acc: 1.00
Epoch: 520, Loss: 0.0073, Acc: 1.00
Epoch: 521, Loss: 0.0073, Acc: 1.00
Epoch: 522, Loss: 0.0075, Acc: 1.00
Epoch: 523, Loss: 0.0073, Acc: 1.00
Epoch: 524, Loss: 0.0073, Acc: 1.00
Epoch: 525, Loss: 0.0070, Acc: 1.00
Epoch: 526, Loss: 0.0076, Acc: 1.00
Epoch: 527, Loss: 0.0071, Ac

Epoch: 753, Loss: 0.0072, Acc: 1.00
Epoch: 754, Loss: 0.0071, Acc: 1.00
Epoch: 755, Loss: 0.0071, Acc: 1.00
Epoch: 756, Loss: 0.0068, Acc: 1.00
Epoch: 757, Loss: 0.0069, Acc: 1.00
Epoch: 758, Loss: 0.0069, Acc: 1.00
Epoch: 759, Loss: 0.0072, Acc: 1.00
Epoch: 760, Loss: 0.0069, Acc: 1.00
Epoch: 761, Loss: 0.0069, Acc: 1.00
Epoch: 762, Loss: 0.0069, Acc: 1.00
Epoch: 763, Loss: 0.0070, Acc: 1.00
Epoch: 764, Loss: 0.0068, Acc: 1.00
Epoch: 765, Loss: 0.0070, Acc: 1.00
Epoch: 766, Loss: 0.0071, Acc: 1.00
Epoch: 767, Loss: 0.0069, Acc: 1.00
Epoch: 768, Loss: 0.0071, Acc: 1.00
Epoch: 769, Loss: 0.0069, Acc: 1.00
Epoch: 770, Loss: 0.0069, Acc: 1.00
Epoch: 771, Loss: 0.0072, Acc: 1.00
Epoch: 772, Loss: 0.0068, Acc: 1.00
Epoch: 773, Loss: 0.0069, Acc: 1.00
Epoch: 774, Loss: 0.0070, Acc: 1.00
Epoch: 775, Loss: 0.0070, Acc: 1.00
Epoch: 776, Loss: 0.0071, Acc: 1.00
Epoch: 777, Loss: 0.0072, Acc: 1.00
Epoch: 778, Loss: 0.0071, Acc: 1.00
Epoch: 779, Loss: 0.0071, Acc: 1.00
Epoch: 780, Loss: 0.0070, Ac

Epoch: 984, Loss: 0.0067, Acc: 1.00
Epoch: 985, Loss: 0.0068, Acc: 1.00
Epoch: 986, Loss: 0.0067, Acc: 1.00
Epoch: 987, Loss: 0.0070, Acc: 1.00
Epoch: 988, Loss: 0.0068, Acc: 1.00
Epoch: 989, Loss: 0.0071, Acc: 1.00
Epoch: 990, Loss: 0.0067, Acc: 1.00
Epoch: 991, Loss: 0.0065, Acc: 1.00
Epoch: 992, Loss: 0.0068, Acc: 1.00
Epoch: 993, Loss: 0.0068, Acc: 1.00
Epoch: 994, Loss: 0.0065, Acc: 1.00
Epoch: 995, Loss: 0.0068, Acc: 1.00
Epoch: 996, Loss: 0.0069, Acc: 1.00
Epoch: 997, Loss: 0.0068, Acc: 1.00
Epoch: 998, Loss: 0.0068, Acc: 1.00
Epoch: 999, Loss: 0.0065, Acc: 1.00
Epoch: 1000, Loss: 0.0065, Acc: 1.00
Epoch: 1001, Loss: 0.0069, Acc: 1.00
Epoch: 1002, Loss: 0.0070, Acc: 1.00
Epoch: 1003, Loss: 0.0071, Acc: 1.00
Epoch: 1004, Loss: 0.0067, Acc: 1.00
Epoch: 1005, Loss: 0.0069, Acc: 1.00
Epoch: 1006, Loss: 0.0067, Acc: 1.00
Epoch: 1007, Loss: 0.0068, Acc: 1.00
Epoch: 1008, Loss: 0.0068, Acc: 1.00
Epoch: 1009, Loss: 0.0067, Acc: 1.00
Epoch: 1010, Loss: 0.0068, Acc: 1.00
Epoch: 1011, Loss

Epoch: 1215, Loss: 0.0066, Acc: 1.00
Epoch: 1216, Loss: 0.0066, Acc: 1.00
Epoch: 1217, Loss: 0.0063, Acc: 1.00
Epoch: 1218, Loss: 0.0065, Acc: 1.00
Epoch: 1219, Loss: 0.0068, Acc: 1.00
Epoch: 1220, Loss: 0.0064, Acc: 1.00
Epoch: 1221, Loss: 0.0066, Acc: 1.00
Epoch: 1222, Loss: 0.0063, Acc: 1.00
Epoch: 1223, Loss: 0.0061, Acc: 1.00
Epoch: 1224, Loss: 0.0066, Acc: 1.00
Epoch: 1225, Loss: 0.0068, Acc: 1.00
Epoch: 1226, Loss: 0.0065, Acc: 1.00
Epoch: 1227, Loss: 0.0063, Acc: 1.00
Epoch: 1228, Loss: 0.0063, Acc: 1.00
Epoch: 1229, Loss: 0.0065, Acc: 1.00
Epoch: 1230, Loss: 0.0067, Acc: 1.00
Epoch: 1231, Loss: 0.0066, Acc: 1.00
Epoch: 1232, Loss: 0.0061, Acc: 1.00
Epoch: 1233, Loss: 0.0068, Acc: 1.00
Epoch: 1234, Loss: 0.0066, Acc: 1.00
Epoch: 1235, Loss: 0.0066, Acc: 1.00
Epoch: 1236, Loss: 0.0065, Acc: 1.00
Epoch: 1237, Loss: 0.0065, Acc: 1.00
Epoch: 1238, Loss: 0.0063, Acc: 1.00
Epoch: 1239, Loss: 0.0066, Acc: 1.00
Epoch: 1240, Loss: 0.0066, Acc: 1.00
Epoch: 1241, Loss: 0.0061, Acc: 1.00
E

Epoch: 1461, Loss: 0.0063, Acc: 1.00
Epoch: 1462, Loss: 0.0063, Acc: 1.00
Epoch: 1463, Loss: 0.0063, Acc: 1.00
Epoch: 1464, Loss: 0.0061, Acc: 1.00
Epoch: 1465, Loss: 0.0063, Acc: 1.00
Epoch: 1466, Loss: 0.0063, Acc: 1.00
Epoch: 1467, Loss: 0.0061, Acc: 1.00
Epoch: 1468, Loss: 0.0064, Acc: 1.00
Epoch: 1469, Loss: 0.0063, Acc: 1.00
Epoch: 1470, Loss: 0.0063, Acc: 1.00
Epoch: 1471, Loss: 0.0065, Acc: 1.00
Epoch: 1472, Loss: 0.0064, Acc: 1.00
Epoch: 1473, Loss: 0.0062, Acc: 1.00
Epoch: 1474, Loss: 0.0061, Acc: 1.00
Epoch: 1475, Loss: 0.0063, Acc: 1.00
Epoch: 1476, Loss: 0.0065, Acc: 1.00
Epoch: 1477, Loss: 0.0063, Acc: 1.00
Epoch: 1478, Loss: 0.0062, Acc: 1.00
Epoch: 1479, Loss: 0.0061, Acc: 1.00
Epoch: 1480, Loss: 0.0062, Acc: 1.00
Epoch: 1481, Loss: 0.0063, Acc: 1.00
Epoch: 1482, Loss: 0.0061, Acc: 1.00
Epoch: 1483, Loss: 0.0063, Acc: 1.00
Epoch: 1484, Loss: 0.0064, Acc: 1.00
Epoch: 1485, Loss: 0.0062, Acc: 1.00
Epoch: 1486, Loss: 0.0065, Acc: 1.00
Epoch: 1487, Loss: 0.0064, Acc: 1.00
E

Epoch: 1708, Loss: 0.0059, Acc: 1.00
Epoch: 1709, Loss: 0.0063, Acc: 1.00
Epoch: 1710, Loss: 0.0060, Acc: 1.00
Epoch: 1711, Loss: 0.0061, Acc: 1.00
Epoch: 1712, Loss: 0.0058, Acc: 1.00
Epoch: 1713, Loss: 0.0061, Acc: 1.00
Epoch: 1714, Loss: 0.0062, Acc: 1.00
Epoch: 1715, Loss: 0.0059, Acc: 1.00
Epoch: 1716, Loss: 0.0062, Acc: 1.00
Epoch: 1717, Loss: 0.0061, Acc: 1.00
Epoch: 1718, Loss: 0.0060, Acc: 1.00
Epoch: 1719, Loss: 0.0061, Acc: 1.00
Epoch: 1720, Loss: 0.0061, Acc: 1.00
Epoch: 1721, Loss: 0.0061, Acc: 1.00
Epoch: 1722, Loss: 0.0061, Acc: 1.00
Epoch: 1723, Loss: 0.0061, Acc: 1.00
Epoch: 1724, Loss: 0.0060, Acc: 1.00
Epoch: 1725, Loss: 0.0060, Acc: 1.00


KeyboardInterrupt: 

# Evaluate in test data

In [27]:
testX, testY = X[:100], y[:100]

In [28]:
y_pred = model(testX).argmax(1)

In [29]:
predictions = zip(range(1, 101), y_pred.tolist())

In [31]:
count = 0
for i in range(100):
    correct = y_pred[i] == testY[i]
    output = fizz_buzz_decode(i + 1, y_pred[i])
    if correct:
        print(output)
    else:
        print(f"\x1b[31m{output}\x1b[0m")
    if not correct: count += 1
print(f'{count} out of 100 wrong')

1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz
[31m21[0m
22
23
fizz
[31m25[0m
26
fizz
28
29
fizzbuzz
31
32
fizz
34
buzz
fizz
37
38
fizz
buzz
41
fizz
43
44
fizzbuzz
46
47
fizz
49
buzz
fizz
52
53
fizz
buzz
56
fizz
58
59
fizzbuzz
61
62
fizz
64
buzz
fizz
67
68
fizz
buzz
71
fizz
73
74
fizzbuzz
76
77
fizz
79
buzz
[31m81[0m
82
83
[31m84[0m
buzz
86
fizz
88
89
fizzbuzz
91
92
fizz
94
buzz
fizz
97
98
fizz
buzz
4 out of 100 wrong
