In [4]:
import numpy as np
import matplotlib.pyplot as plt
import dezero
from dezero import Model
from dezero import SeqDataLoader
import dezero.functions as F
import dezero.layers as L


max_epoch = 100
batch_size = 30
hidden_size = 100
bptt_length = 30

train_set = dezero.datasets.SinCurve(train=True)
dataloader = SeqDataLoader(train_set, batch_size=batch_size)
seqlen = len(train_set)


class BetterRNN(Model):
    def __init__(self, hidden_size, out_size):
        super().__init__()
        self.rnn = L.LSTM(hidden_size)
        self.fc = L.Linear(out_size)

    def reset_state(self):
        self.rnn.reset_state()

    def __call__(self, x):
        y = self.rnn(x)
        y = self.fc(y)
        return y

model = BetterRNN(hidden_size, 1)
optimizer = dezero.optimizers.Adam().setup(model)

for epoch in range(max_epoch):
    model.reset_state()
    loss, count = 0, 0

    for x, t in dataloader:
        print(x, t)
        y = model(x)
        loss += F.mean_squared_error(y, t)
        count += 1

        if count % bptt_length == 0 or count == seqlen:
            model.cleargrads()
            loss.backward()
            loss.unchain_backward()
            optimizer.update()
    avg_loss = float(loss.data) / count
    print('| epoch %d | loss %f' % (epoch + 1, avg_loss))

# Plot
xs = np.cos(np.linspace(0, 4 * np.pi, 1000))
model.reset_state()
pred_list = []

with dezero.no_grad():
    for x in xs:
        x = np.array(x).reshape(1, 1)
        y = model(x)
        pred_list.append(float(y.data))

plt.plot(np.arange(len(xs)), xs, label='y=cos(x)')
plt.plot(np.arange(len(xs)), pred_list, label='predict')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()

[[-0.02586095]
 [ 0.20297278]
 [ 0.43455673]
 [ 0.55253686]
 [ 0.7735074 ]
 [ 0.90003375]
 [ 0.97463924]
 [ 0.98037676]
 [ 1.01344664]
 [ 0.98822586]
 [ 0.90348027]
 [ 0.78702054]
 [ 0.61883588]
 [ 0.41376095]
 [ 0.21662224]
 [ 0.0134528 ]
 [-0.19968589]
 [-0.36862853]
 [-0.54297709]
 [-0.73318676]
 [-0.88696287]
 [-0.96586129]
 [-1.01996315]
 [-0.99905739]
 [-0.97348797]
 [-0.9328483 ]
 [-0.74218042]
 [-0.63282008]
 [-0.45073782]
 [-0.30103157]] [[-0.00373502]
 [ 0.17666931]
 [ 0.42816028]
 [ 0.60577272]
 [ 0.76356497]
 [ 0.82724227]
 [ 0.9309286 ]
 [ 0.95425979]
 [ 0.95118541]
 [ 0.95878146]
 [ 0.91512725]
 [ 0.74182392]
 [ 0.62219063]
 [ 0.43940572]
 [ 0.26870245]
 [ 0.03364688]
 [-0.16781507]
 [-0.35736482]
 [-0.59306989]
 [-0.76244848]
 [-0.8166672 ]
 [-0.97807529]
 [-1.01394554]
 [-0.97778058]
 [-0.97904851]
 [-0.87575978]
 [-0.79679872]
 [-0.65859698]
 [-0.40381556]
 [-0.20751005]]
[[-0.00373502]
 [ 0.17666931]
 [ 0.42816028]
 [ 0.60577272]
 [ 0.76356497]
 [ 0.82724227]
 [ 0.930

KeyboardInterrupt: 