In [1]:
# http://joelgrus.com/2016/05/23/fizz-buzz-in-tensorflow/

import numpy as np

import torch
from torch.autograd import Variable

# Number of binary digits used to encode each input number; it is also the
# model's input width, and 2 ** NUM_DIGITS bounds the training range.
NUM_DIGITS = 10
# Number of units in the model's single hidden layer.
NUM_HIDDEN = 100
# Number of training examples per minibatch in the training loop.
BATCH_SIZE = 128


# Represent each input by an array of its binary digits.
def binary_encode(i, num_digits):
    """Return the lowest `num_digits` bits of `i` as a NumPy array.

    The bits are ordered least-significant first, so element `k` of the
    result is bit `k` of `i`.
    """
    bits = []
    for shift in range(num_digits):
        bits.append((i >> shift) & 1)
    return np.array(bits)


# Encode the desired output as a class index: 0 = number, 1 = "fizz", 2 = "buzz", 3 = "fizzbuzz"
def fizz_buzz_encode(i):
    """Return the fizz-buzz class index of `i`.

    3 for multiples of 15, 2 for multiples of 5, 1 for multiples of 3,
    and 0 for every other number.
    """
    # Checked most-specific first so that multiples of 15 are not reported
    # as plain multiples of 5 or 3.
    for divisor, label in ((15, 3), (5, 2), (3, 1)):
        if i % divisor == 0:
            return label
    return 0

    
def fizz_buzz_decode(i, prediction):
    """Return the text for number `i` given its predicted class index.

    Index 0 yields the number itself as a string; 1, 2 and 3 yield
    "fizz", "buzz" and "fizzbuzz" respectively.
    """
    labels = [str(i), "fizz", "buzz", "fizzbuzz"]
    return labels[prediction]

# Training set: every integer from 101 up to (but not including) 2 ** NUM_DIGITS.
# trX holds one row of binary digits per number; trY holds the matching class index.
trX = Variable(
    torch.Tensor(
        [binary_encode(n, NUM_DIGITS) for n in range(101, 2 ** NUM_DIGITS)]
    )
)
trY = Variable(
    torch.LongTensor(
        list(map(fizz_buzz_encode, range(101, 2 ** NUM_DIGITS)))
    )
)


# Define the model
# Two-layer perceptron: NUM_DIGITS input bits -> NUM_HIDDEN ReLU units -> 4 class scores.
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=NUM_DIGITS, out_features=NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=NUM_HIDDEN, out_features=4),
)
# The loss is fed the raw scores from the last layer together with integer class targets.
loss_fn = torch.nn.CrossEntropyLoss()
# Plain stochastic gradient descent over all model parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)


# Start training it
for epoch in range(10000):
    # One pass over the training set in fixed-order minibatches of BATCH_SIZE
    # rows (the last batch may be shorter; slicing past the end is harmless).
    for start in range(0, len(trX), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = trX[start:end]
        batchY = trY[start:end]

        y_pred = model(batchX)
        loss = loss_fn(y_pred, batchY)

        # Clear gradients left over from the previous step before
        # back-propagating, then apply the parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the loss over the whole training set every 500 epochs. The
    # full-set forward pass is only run on reporting epochs: its result is
    # not used anywhere else, so evaluating it every epoch was wasted work.
    if epoch % 500 == 0:
        train_loss = loss_fn(model(trX), trY)
        print('Epoch: %d Loss: %f' % (epoch, train_loss))


# Output now
# Evaluate on the numbers 1..100 and print the decoded fizz-buzz sequence.
testX = Variable(
    torch.Tensor([binary_encode(n, NUM_DIGITS) for n in range(1, 101)])
)
testY = model(testX)
# max over dimension 1 returns (values, indices); the indices are the
# predicted class for each row.
predictions = zip(range(1, 101), testY.max(1)[1].data.tolist())

print([fizz_buzz_decode(*pair) for pair in predictions])

Epoch: 0 Loss: 1.180215
Epoch: 500 Loss: 0.933845
Epoch: 1000 Loss: 0.406614
Epoch: 1500 Loss: 0.214218
Epoch: 2000 Loss: 0.141949
Epoch: 2500 Loss: 0.104820
Epoch: 3000 Loss: 0.079142
Epoch: 3500 Loss: 0.060694
Epoch: 4000 Loss: 0.047331
Epoch: 4500 Loss: 0.037541
Epoch: 5000 Loss: 0.030341
Epoch: 5500 Loss: 0.024999
Epoch: 6000 Loss: 0.020991
Epoch: 6500 Loss: 0.017892
Epoch: 7000 Loss: 0.015471
Epoch: 7500 Loss: 0.013532
Epoch: 8000 Loss: 0.011990
Epoch: 8500 Loss: 0.010741
Epoch: 9000 Loss: 0.009695
Epoch: 9500 Loss: 0.008816
['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', 'buzz', 'fizz', '19', 'buzz', 'fizz', '22', '23', 'fizz', '25', '26', 'fizz', '28', '29', 'fizzbuzz', '31', 'buzz', 'fizz', '34', 'buzz', 'fizz', '37', '38', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', 'buzz', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz'