# FizzBuzz

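FizzBuzz is a simple game. The rules: count upward from 1. For a multiple of 3, say "fizz". For a multiple of 5, say "buzz". For a multiple of 15, say "fizzbuzz". Otherwise, say the number itself.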

We can write a small program that decides whether to return the plain number or "fizz", "buzz", or "fizzbuzz".

In [1]:
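# Encode each desired output as a class index: 0 = the number itself, 1 = "fizz", 2 = "buzz", 3 = "fizzbuzz"
# (the multiple-of-15 test must come first, otherwise 15 would be caught by the 5 or 3 branch)
def fizz_buzz_encode(i):
    if   i % 15 == 0: return 3
    elif i % 5  == 0: return 2
    elif i % 3  == 0: return 1
    else:             return 0

# Map a class index back to the string the game expects for the number i
def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]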

print(fizz_buzz_decode(1, fizz_buzz_encode(1)))
print(fizz_buzz_decode(2, fizz_buzz_encode(2)))
print(fizz_buzz_decode(5, fizz_buzz_encode(5)))
print(fizz_buzz_decode(12, fizz_buzz_encode(12)))
print(fizz_buzz_decode(15, fizz_buzz_encode(15)))

1
2
buzz
fizz
fizzbuzz
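As a cross-check, the encode/decode pair above should reproduce the game exactly. The sketch below restates both functions and compares them, on 1 to 100, with `fizz_buzz_rules`, a hypothetical helper that writes the rules out directly with no class-index step.

```python
def fizz_buzz_encode(i):
    # Class index: 0 = the number itself, 1 = fizz, 2 = buzz, 3 = fizzbuzz
    if   i % 15 == 0: return 3
    elif i % 5  == 0: return 2
    elif i % 3  == 0: return 1
    else:             return 0

def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

def fizz_buzz_rules(i):
    # Hypothetical helper: the game rules written out directly
    out = ""
    if i % 3 == 0:
        out += "fizz"
    if i % 5 == 0:
        out += "buzz"
    return out or str(i)

# Numbers where the two versions disagree (should be none)
mismatches = [i for i in range(1, 101)
              if fizz_buzz_decode(i, fizz_buzz_encode(i)) != fizz_buzz_rules(i)]
print(mismatches)  # []
```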


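First, we define the model's inputs and outputs (the training data). Each number is represented by its binary digits, and we train on the numbers 101 to 1023, so that 1 to 100 remain unseen until the final test.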

In [2]:
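import numpy as np
import torch

NUM_DIGITS = 10

# Represent each input by an array of its binary digits (least significant bit first).
def binary_encode(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)])

# Training set: 101..1023 (2**10 - 1); the numbers 1..100 are held out for testing.
# Stacking into one NumPy array first avoids building a tensor from a list of arrays.
trX = torch.tensor(np.array([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)]), dtype=torch.float32)
trY = torch.tensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)], dtype=torch.long)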
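`binary_encode` stores the least significant bit first, and 10 digits cover 0 to 1023, which is why training stops at `2 ** NUM_DIGITS`. The sketch below checks the encoding on two numbers and against `binary_decode`, a hypothetical inverse added only for this check.

```python
import numpy as np

NUM_DIGITS = 10

def binary_encode(i, num_digits):
    # Bit d of i, least significant bit first (`>>` binds tighter than `&`)
    return np.array([i >> d & 1 for d in range(num_digits)])

def binary_decode(bits):
    # Hypothetical inverse: sum of bit_d * 2**d
    return sum(int(b) << d for d, b in enumerate(bits))

print(binary_encode(5, NUM_DIGITS))    # [1 0 1 0 0 0 0 0 0 0]
print(binary_encode(100, NUM_DIGITS))  # [0 0 1 0 0 1 1 0 0 0]

train_range = range(101, 2 ** NUM_DIGITS)
print(len(train_range))                # 923 training examples

# Every number representable in 10 bits survives an encode/decode round trip
assert all(binary_decode(binary_encode(i, NUM_DIGITS)) == i for i in range(2 ** NUM_DIGITS))
```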

Next, we define the model with PyTorch.

In [3]:
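# Define the model: 10 binary inputs -> 100 hidden units (ReLU) -> 4 class scores (logits)
NUM_HIDDEN = 100
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4)  # one output per class: number, fizz, buzz, fizzbuzz
)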
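The `Sequential` model computes logits = W2 · ReLU(W1 · x + b1) + b2. The NumPy sketch below mirrors that forward pass to show the shapes involved and the parameter count. It uses random weights as an assumption (PyTorch's own initialization differs); in PyTorch the same count would come from `sum(p.numel() for p in model.parameters())`.

```python
import numpy as np

NUM_DIGITS, NUM_HIDDEN, NUM_CLASSES = 10, 100, 4
rng = np.random.default_rng(0)

# Random stand-in weights, stored as (out_features, in_features) like nn.Linear.weight
W1 = rng.normal(scale=0.1, size=(NUM_HIDDEN, NUM_DIGITS))
b1 = np.zeros(NUM_HIDDEN)
W2 = rng.normal(scale=0.1, size=(NUM_CLASSES, NUM_HIDDEN))
b2 = np.zeros(NUM_CLASSES)

def forward(X):
    # X: (batch, NUM_DIGITS) -> logits: (batch, NUM_CLASSES)
    hidden = np.maximum(X @ W1.T + b1, 0.0)  # Linear + ReLU
    return hidden @ W2.T + b2                 # Linear; no softmax here (the loss applies it)

X = rng.integers(0, 2, size=(128, NUM_DIGITS)).astype(float)  # a batch of 128 bit vectors
logits = forward(X)
print(logits.shape)  # (128, 4)

num_params = W1.size + b1.size + W2.size + b2.size
print(num_params)    # 1504 = 10*100 + 100 + 100*4 + 4
```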

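- To teach our model to play FizzBuzz, we need to define a loss function and an optimization algorithm.
- The optimizer repeatedly adjusts the model's parameters to lower the loss, driving the model toward the lowest loss it can reach on this task.
- A low loss usually means the model performs well; a high loss means it performs poorly.
- Since FizzBuzz is essentially a classification problem (each number belongs to one of four classes), we use the cross-entropy loss (`CrossEntropyLoss`).
- As the optimizer we use Stochastic Gradient Descent (SGD).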
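`CrossEntropyLoss` takes raw logits and integer class labels and, with its default mean reduction, returns the batch average of -log softmax(logits)[label]. The NumPy sketch below implements that formula and computes two reference values that help judge what "low" means here: uniform logits give ln 4 ≈ 1.386, and a constant predictor that only knows how often each class occurs in the training range 101 to 1023 gives about 1.136. The losses printed in the first epochs of training (about 1.14) sit just above the second value, which suggests that early on the model has learned little beyond the class frequencies.

```python
import numpy as np

def cross_entropy(logits, labels):
    # Mean over the batch of -log softmax(logits)[label]
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def fizz_buzz_encode(i):
    if   i % 15 == 0: return 3
    elif i % 5  == 0: return 2
    elif i % 3  == 0: return 1
    else:             return 0

labels = np.array([fizz_buzz_encode(i) for i in range(101, 1024)])

# Reference 1: all-zero (uniform) logits -> loss = ln 4
uniform = cross_entropy(np.zeros((len(labels), 4)), labels)

# Reference 2: every row predicts the class frequencies of the training range
freqs = np.bincount(labels, minlength=4) / len(labels)
prior = cross_entropy(np.tile(np.log(freqs), (len(labels), 1)), labels)

print(round(uniform, 3), round(prior, 3))  # 1.386 1.136
```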

In [4]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
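With only `lr` set (no momentum, no weight decay), each `optimizer.step()` of `torch.optim.SGD` moves every parameter against its gradient: θ ← θ - lr · ∂loss/∂θ. The toy sketch below applies that rule to a hypothetical one-parameter loss (θ - 3)² with the same learning rate of 0.05; each step shrinks the distance to the minimum by a factor of 0.9.

```python
def grad(theta):
    # Derivative of the toy loss (theta - 3)**2
    return 2.0 * (theta - 3.0)

theta, lr = 0.0, 0.05
for step in range(200):
    theta -= lr * grad(theta)  # plain SGD update: theta <- theta - lr * gradient

print(round(theta, 4))  # 3.0
```

In the PyTorch loop, `loss.backward()` fills in the gradients and `optimizer.step()` applies this update; "stochastic" refers to computing the gradient on a mini-batch rather than the full dataset.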

The training code for the model follows.

In [5]:
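# Start training
BATCH_SIZE = 128  # number of samples per batch
for epoch in range(10000):
    for start in range(0, len(trX), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = trX[start:end]
        batchY = trY[start:end]

        y_pred = model(batchX)
        loss = loss_fn(y_pred, batchY)  # one loss per batch: the mean cross-entropy over its samples

        optimizer.zero_grad()  # clear gradients from the previous step (they accumulate otherwise)
        loss.backward()        # compute gradients of the batch loss
        optimizer.step()       # update the parameters with SGD

    # After each epoch, report the loss over the entire training set
    with torch.no_grad():
        loss = loss_fn(model(trX), trY).item()
    print('Epoch:', epoch, 'Loss:', loss)
# Batch size: the number of samples used in one training step. It affects both how well and how fast the model is optimized.
# Iteration: one forward pass, backward pass, and parameter update on a single batch of batch-size samples.
# Epoch: one forward/backward/update pass over all the training data, i.e. the whole dataset has been used once.
# For example, with 100000 training samples and a batch size of 100, one epoch takes 1000 iterations.
# https://www.cnblogs.com/dtpromise/articles/11484540.html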

Epoch: 0 Loss: 1.1792891025543213
Epoch: 1 Loss: 1.1541941165924072
Epoch: 2 Loss: 1.1471019983291626
Epoch: 3 Loss: 1.1443066596984863
Epoch: 4 Loss: 1.1429307460784912
Epoch: 5 Loss: 1.1421258449554443
Epoch: 6 Loss: 1.1415908336639404
Epoch: 7 Loss: 1.1411938667297363
Epoch: 8 Loss: 1.1408807039260864
Epoch: 9 Loss: 1.1406155824661255
Epoch: 10 Loss: 1.1403849124908447
Epoch: 11 Loss: 1.1401779651641846
Epoch: 12 Loss: 1.1399884223937988
Epoch: 13 Loss: 1.1398112773895264
Epoch: 14 Loss: 1.1396459341049194
Epoch: 15 Loss: 1.1394920349121094
Epoch: 16 Loss: 1.1393471956253052
Epoch: 17 Loss: 1.1392076015472412
Epoch: 18 Loss: 1.1390762329101562
Epoch: 19 Loss: 1.138946533203125
Epoch: 20 Loss: 1.1388249397277832
Epoch: 21 Loss: 1.1387059688568115
Epoch: 22 Loss: 1.1385916471481323
Epoch: 23 Loss: 1.1384828090667725
Epoch: 24 Loss: 1.1383742094039917
Epoch: 25 Loss: 1.1382704973220825
Epoch: 26 Loss: 1.1381683349609375
Epoch: 27 Loss: 1.138067603111267
Epoch: 28 Loss: 1.13797068595886
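With 923 training examples (101 to 1023) and `BATCH_SIZE = 128`, the inner loop runs ceil(923 / 128) = 8 iterations per epoch: seven full batches and a final batch of 27, so 10000 epochs amount to 80000 parameter updates. The loop also visits the batches in the same order every epoch. Reshuffling once per epoch is a common variation that the original code does not use; in PyTorch it could be done with `perm = torch.randperm(len(trX))` and indexing `trX[perm]`, `trY[perm]`. The sketch below uses plain index lists in place of `trX`/`trY` to show both the batch sizes and a shuffled split.

```python
import math
import random

NUM_EXAMPLES, BATCH_SIZE = 923, 128

# Same slicing as the training loop: start = 0, 128, 256, ...
batch_sizes = [len(range(start, min(start + BATCH_SIZE, NUM_EXAMPLES)))
               for start in range(0, NUM_EXAMPLES, BATCH_SIZE)]
print(len(batch_sizes), batch_sizes[-1])  # 8 27
assert len(batch_sizes) == math.ceil(NUM_EXAMPLES / BATCH_SIZE)

# Variation (not in the original): shuffle example indices before slicing each epoch
indices = list(range(NUM_EXAMPLES))
random.shuffle(indices)
batches = [indices[start:start + BATCH_SIZE] for start in range(0, NUM_EXAMPLES, BATCH_SIZE)]

# Every example still appears exactly once per epoch
assert sorted(i for b in batches for i in b) == list(range(NUM_EXAMPLES))
```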

Finally, we use the trained model to play FizzBuzz on the numbers 1 to 100.

In [1]:
# Predict on the held-out numbers 1..100
testX = torch.tensor(np.array([binary_encode(i, NUM_DIGITS) for i in range(1, 101)]), dtype=torch.float32)
with torch.no_grad():
    testY = model(testX)
# The predicted class for each number is the index of its largest logit
predictions = zip(range(1, 101), testY.argmax(dim=1).tolist())

print([fizz_buzz_decode(i, x) for (i, x) in predictions])

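`testY` holds one row of 4 logits per number. Taking the index of the largest value in each row (`argmax`, equivalent to `max(1)[1]`) gives the predicted class, which `fizz_buzz_decode` turns into the game's answer. The NumPy sketch below walks through that step with made-up logits for 1 to 5 (hypothetical values, not the trained model's output).

```python
import numpy as np

def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

# Hypothetical logits: one row per number, one column per class
logits = np.array([
    [2.0, 0.1, 0.3, -1.0],  # 1
    [1.5, 0.2, 0.0, -0.5],  # 2
    [0.1, 1.8, 0.2, -0.3],  # 3
    [1.2, 0.4, 0.1, -0.2],  # 4
    [0.0, 0.3, 2.2,  0.5],  # 5
])
predicted = logits.argmax(axis=1)  # index of the largest score in each row
answers = [fizz_buzz_decode(i, int(p)) for i, p in zip(range(1, 6), predicted)]
print(answers)  # ['1', '2', 'fizz', '4', 'buzz']
```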

In [7]:
# Compare the predicted classes with the true classes for 1..100
truth = np.array([fizz_buzz_encode(i) for i in range(1, 101)])
print(np.sum(testY.max(1)[1].numpy() == truth))  # number of correct answers out of 100
testY.max(1)[1].numpy() == truth                 # per-number correctness

96


array([ True,  True,  True, False,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True, False,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True, False,  True,  True,  True,  True,  True,  True,
       False])
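The model answers 96 of the 100 numbers correctly. The `False` entries sit at indices 3, 24, 92, and 99, i.e. the numbers 4, 25, 93, and 100. The sketch below defines `report_errors`, a hypothetical helper that lists such mistakes directly from any array of predicted class indices. It is demonstrated on a hand-made prediction array that reproduces those four errors, not on the trained model itself.

```python
import numpy as np

def fizz_buzz_encode(i):
    if   i % 15 == 0: return 3
    elif i % 5  == 0: return 2
    elif i % 3  == 0: return 1
    else:             return 0

def report_errors(predicted, numbers):
    # Hypothetical helper: count correct predictions and list the numbers that were wrong
    truth = np.array([fizz_buzz_encode(i) for i in numbers])
    correct = predicted == truth
    wrong = [n for n, ok in zip(numbers, correct) if not ok]
    return int(correct.sum()), wrong

numbers = list(range(1, 101))
# Hand-made predictions: correct everywhere except 4, 25, 93 and 100
predicted = np.array([fizz_buzz_encode(i) for i in numbers])
for n in (4, 25, 93, 100):
    predicted[n - 1] = (predicted[n - 1] + 1) % 4  # force a wrong class

print(report_errors(predicted, numbers))  # (96, [4, 25, 93, 100])
```

With the notebook's variables, the same helper could be called as `report_errors(testY.argmax(dim=1).numpy(), list(range(1, 101)))`.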