Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion OpenNMT/onmt/Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self, srcData, tgtData, batchSize, cuda):
self.cuda = cuda

self.batchSize = batchSize
self.numBatches = len(self.src) // batchSize
self.numBatches = (len(self.src) + batchSize - 1) // batchSize

def _batchify(self, data, align_right=False):
max_length = max(x.size(0) for x in data)
Expand Down
17 changes: 17 additions & 0 deletions time_sequence_prediction/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Time Sequence Prediction by pytorch
This is an simple and illustrative example showing how to model the time sequence with LSTM(Long Short-Term Memory) . Several sine waves are input into the model and the modle learns how to generate the future waves according to given states. The result is as following.
![image](https://cloud.githubusercontent.com/assets/1419566/23689065/1d6e9900-03f3-11e7-958b-80066f2e9472.png)

## Usage
1. Generate the training data
```
python generate_sine_wave.py
```

2. Train the model and predict the future states
```
python train.py
```

## The model
Stacked LSTM nodes are used to learn the patterns of the input signal.
12 changes: 12 additions & 0 deletions time_sequence_prediction/generate_sine_wave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import math
import numpy as np
import cPickle as pickle
T = 20
L = 1000
N = 100
np.random.seed(2)
x = np.empty((N, L), 'int64')
x[:] = np.array(range(L)) + np.random.randint(-4*T, 4*T, N).reshape(N, 1)
y = np.sin(x / 1.0 / T).astype('float64')
pickle.dump(y, open('traindata.pkl', 'wb'))

81 changes: 81 additions & 0 deletions time_sequence_prediction/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
import cPickle as pickle
import torch.optim as optim
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

class Sequence(nn.Module):
def __init__(self):
super(Sequence, self).__init__()
self.lstm1 = nn.LSTMCell(1, 51)
self.lstm2 = nn.LSTMCell(51, 1)

def forward(self, input, future = 0):
outputs = []
h_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
c_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False)
h_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)
c_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False)

for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
h_t, c_t = self.lstm1(input_t, (h_t, c_t))
h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
outputs += [c_t2]
for i in range(future):
h_t, c_t = self.lstm1(c_t2, (h_t, c_t))
h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2))
outputs += [c_t2]
outputs = torch.stack(outputs, 1).squeeze(2)
return outputs



if __name__ == '__main__':
# set ramdom seed to 0
np.random.seed(0)
torch.manual_seed(0)
# load data and make training set
data = pickle.load(open('traindata.pkl'))
input = Variable(torch.from_numpy(data[3:, :-1]), requires_grad=False)
target = Variable(torch.from_numpy(data[3:, 1:]), requires_grad=False)
# build the model
seq = Sequence()
seq.double()
criterion = nn.MSELoss()
# use LBFGS as optimizer since we can load the whole data to train
optimizer = optim.LBFGS(seq.parameters())
#begin to train
for i in range(20):
print i
def closure():
optimizer.zero_grad()
out = seq(input)
loss = criterion(out, target)
print 'loss:', loss.data.numpy()[0]
loss.backward()
return loss
optimizer.step(closure)
# begin to predict
future = 1000
pred = seq(input[:3], future = future)
y = pred.data.numpy()
# draw the result
plt.figure(figsize=(30,10))
plt.title('Predict future values for time sequences\n(Dashlines are predicted values)', fontsize=30)
plt.xlabel('x', fontsize=20)
plt.ylabel('y', fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
def draw(yi, color):
plt.plot(np.arange(input.size(1)), yi[:input.size(1)], color, linewidth = 2.0)
plt.plot(np.arange(input.size(1), input.size(1) + future), yi[input.size(1):], color + ':', linewidth = 2.0)
draw(y[0], 'r')
draw(y[1], 'g')
draw(y[2], 'b')
plt.savefig('predict%d.pdf'%i)
plt.close()