In [10]:
"""
FizzBuzz is the following problem:

For each of the numbers 1 to 100:
* if the number is divisible by 3, print "fizz"
* if the number is divisible by 5, print "buzz"
* if the number is divisible by 15, print "fizzbuzz"
* otherwise, just print the number
"""
from typing import List

import jax.numpy as np
import jax.random as jxr
import numpy as onp
from tqdm.autonotebook import tqdm

from colin_net.layers import Linear, Tanh, Softmax
from colin_net.nn import NeuralNet
from colin_net.data import BatchIterator
from colin_net.train import train
from colin_net.loss import mean_sqaured_error


key = jxr.PRNGKey(42)

def fizz_buzz_encode(x: int) -> List[int]:
    if x % 15 == 0:
        return [0, 0, 0, 1]
    elif x % 5 == 0:
        return [0, 0, 1, 0]
    elif x % 3 == 0:
        return [0, 1, 0, 0]
    else:
        return [1, 0, 0, 0]


def binary_encode(x: int) -> List[int]:
    """
    10 digit binary encoding of x
    """
    return [x >> i & 1 for i in range(10)]

inputs = np.array([
    binary_encode(x)
    for x in range(101, 1024)
])

targets = np.array([
    fizz_buzz_encode(x)
    for x in range(101, 1024)
])

net = NeuralNet([
    Linear.initialize(input_size=10, output_size=50, key=key),
    Tanh(),
    Linear.initialize(input_size=50, output_size=4, key=key),
    Softmax()
])

iterator = BatchIterator(inputs=inputs, targets=targets)


# define accuracy calculation
def accuracy(actual, predicted):
    return np.mean(np.argmax(actual, axis=1) == np.argmax(predicted, axis=1))



num_epochs = 5000

progress = train(net,
      loss=mean_sqaured_error,
      iterator=iterator,
      num_epochs=num_epochs,
      lr=0.001)


points = []
for i, (epoch, loss, net) in enumerate(tqdm(progress, total=num_epochs)):
    
    # check loss and accuracy every 5 epochs
    if i % 5 == 0:
        print(epoch, loss)
        predicted = net(inputs)

        if accuracy(targets, predicted) >= 0.99:
            print("Achieved Perfect Prediction!")
            points.append([epoch, loss])
            break
    points.append([epoch, loss])


for x in range(1, 101):
    predicted = net.predict(np.array(binary_encode(x)))
    predicted_idx = np.argmax(predicted)
    actual_idx = np.argmax(fizz_buzz_encode(x))
    labels = [str(x), "fizz", "buzz", "fizzbuzz"]
    print(x, labels[predicted_idx], labels[actual_idx])


HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))

0 9.955859
5 9.943995
10 9.931906
15 9.919583
20 9.907022
25 9.89421
30 9.881131
35 9.867778
40 9.854136
45 9.8401985
50 9.825954
55 9.811393
60 9.796512
65 9.781302
70 9.765763
75 9.749893
80 9.733696
85 9.7171755
90 9.700343
95 9.683205
100 9.665772
105 9.648065
110 9.630095
115 9.611877
120 9.593427
125 9.574757
130 9.5558815
135 9.536812
140 9.517549
145 9.498098
150 9.478463
155 9.458632
160 9.438606
165 9.418368
170 9.39791
175 9.3772135
180 9.356264
185 9.335031
190 9.313507
195 9.291674
200 9.269509
205 9.247008
210 9.224153
215 9.20093
220 9.1773405
225 9.153376
230 9.129044
235 9.104341
240 9.079286
245 9.05389
250 9.028173
255 9.002151
260 8.975837
265 8.94927
270 8.922459
275 8.895438
280 8.868224
285 8.840851
290 8.813335
295 8.785709
300 8.757987
305 8.730195
310 8.702343
315 8.674469
320 8.646584
325 8.618709
330 8.590871
335 8.56307
340 8.53535
345 8.507713
350 8.4802
355 8.452824
360 8.425627
365 8.398628
370 8.371863
375 8.345352
380 8.3191395
385 8.293239
390 8.26768

KeyboardInterrupt: 

In [17]:
accuracy(targets, net(inputs))

DeviceArray(0.42470205, dtype=float32)