From f45992f1b6586fea1addaef3fe3cbaf693e7a3f8 Mon Sep 17 00:00:00 2001 From: fuzihao Date: Wed, 22 Mar 2017 14:05:31 +0800 Subject: [PATCH 1/8] first commit --- time_sequence_prediction/README.md | 1 + .../generate_sine_wave.py | 12 +++ time_sequence_prediction/train.py | 81 +++++++++++++++++++ 3 files changed, 94 insertions(+) create mode 100644 time_sequence_prediction/README.md create mode 100644 time_sequence_prediction/generate_sine_wave.py create mode 100644 time_sequence_prediction/train.py diff --git a/time_sequence_prediction/README.md b/time_sequence_prediction/README.md new file mode 100644 index 0000000000..e6af43f461 --- /dev/null +++ b/time_sequence_prediction/README.md @@ -0,0 +1 @@ +Time Sequence Prediction diff --git a/time_sequence_prediction/generate_sine_wave.py b/time_sequence_prediction/generate_sine_wave.py new file mode 100644 index 0000000000..8a68a57b5d --- /dev/null +++ b/time_sequence_prediction/generate_sine_wave.py @@ -0,0 +1,12 @@ +import math +import numpy as np +import cPickle as pickle +T = 20 +L = 1000 +N = 100 +np.random.seed(2) +x = np.empty((N, L), 'int64') +x[:] = np.array(range(L)) + np.random.randint(-4*T, 4*T, N).reshape(N, 1) +y = np.sin(x / 1.0 / T).astype('float64') +pickle.dump(y, open('traindata.pkl', 'wb')) + diff --git a/time_sequence_prediction/train.py b/time_sequence_prediction/train.py new file mode 100644 index 0000000000..60036f9914 --- /dev/null +++ b/time_sequence_prediction/train.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +import cPickle as pickle +import torch.optim as optim +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +class Sequence(nn.Module): + def __init__(self): + super(Sequence, self).__init__() + self.lstm1 = nn.LSTMCell(1, 51) + self.lstm2 = nn.LSTMCell(51, 1) + + def forward(self, input, future = 0): + outputs = [] + h_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False) + c_t = Variable(torch.zeros(input.size(0), 51).double(), requires_grad=False) + h_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False) + c_t2 = Variable(torch.zeros(input.size(0), 1).double(), requires_grad=False) + + for i, input_t in enumerate(input.chunk(input.size(1), dim=1)): + h_t, c_t = self.lstm1(input_t, (h_t, c_t)) + h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2)) + outputs += [c_t2] + for i in range(future): + h_t, c_t = self.lstm1(c_t2, (h_t, c_t)) + h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2)) + outputs += [c_t2] + outputs = torch.stack(outputs, 1).squeeze(2) + return outputs + + + +if __name__ == '__main__': + # set ramdom seed to 0 + np.random.seed(0) + torch.manual_seed(0) + # load data and make training set + data = pickle.load(open('traindata.pkl')) + input = Variable(torch.from_numpy(data[3:, :-1]), requires_grad=False) + target = Variable(torch.from_numpy(data[3:, 1:]), requires_grad=False) + # build the model + seq = Sequence() + seq.double() + criterion = nn.MSELoss() + # use LBFGS as optimizer since we can load the whole data to train + optimizer = optim.LBFGS(seq.parameters()) + #begin to train + for i in range(20): + print i + def closure(): + optimizer.zero_grad() + out = seq(input) + loss = criterion(out, target) + print 'loss:', loss.data.numpy()[0] + loss.backward() + return loss + optimizer.step(closure) + # begin to predict + future = 1000 + pred = seq(input[:3], future = future) + y = pred.data.numpy() + # draw the result + plt.figure(figsize=(30,10)) + plt.title('Predict future values for time sequences\n(Dashlines are predicted values)', fontsize=30) + plt.xlabel('x', fontsize=20) + plt.ylabel('y', fontsize=20) + plt.xticks(fontsize=20) + plt.yticks(fontsize=20) + def draw(yi, color): + plt.plot(np.arange(input.size(1)), yi[:input.size(1)], color, linewidth = 2.0) + plt.plot(np.arange(input.size(1), input.size(1) + future), yi[input.size(1):], color + ':', linewidth = 2.0) + draw(y[0], 'r') + draw(y[1], 'g') + draw(y[2], 'b') + plt.savefig('predict%d.pdf'%i) + plt.close() + From 5689735aacec9381de00695b1095d770e7f5a6e2 Mon Sep 17 00:00:00 2001 From: fuzihao Date: Wed, 22 Mar 2017 14:16:36 +0800 Subject: [PATCH 2/8] add some comment. --- time_sequence_prediction/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/time_sequence_prediction/train.py b/time_sequence_prediction/train.py index 60036f9914..461d517a42 100644 --- a/time_sequence_prediction/train.py +++ b/time_sequence_prediction/train.py @@ -25,7 +25,7 @@ def forward(self, input, future = 0): h_t, c_t = self.lstm1(input_t, (h_t, c_t)) h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2)) outputs += [c_t2] - for i in range(future): + for i in range(future):# if we should predict the future h_t, c_t = self.lstm1(c_t2, (h_t, c_t)) h_t2, c_t2 = self.lstm2(c_t, (h_t2, c_t2)) outputs += [c_t2] @@ -49,13 +49,13 @@ def forward(self, input, future = 0): # use LBFGS as optimizer since we can load the whole data to train optimizer = optim.LBFGS(seq.parameters()) #begin to train - for i in range(20): - print i + for i in range(10): + print('STEP: ', i) def closure(): optimizer.zero_grad() out = seq(input) loss = criterion(out, target) - print 'loss:', loss.data.numpy()[0] + print('loss:', loss.data.numpy()[0]) loss.backward() return loss optimizer.step(closure) From 40d416f321083f5a06c69d26ffc2ffed6cd4f6a7 Mon Sep 17 00:00:00 2001 From: fuzihao Date: Wed, 22 Mar 2017 14:19:55 +0800 Subject: [PATCH 3/8] modify print --- time_sequence_prediction/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/time_sequence_prediction/train.py b/time_sequence_prediction/train.py index 461d517a42..a5c3718403 100644 --- a/time_sequence_prediction/train.py +++ b/time_sequence_prediction/train.py @@ -1,3 +1,4 @@ +from __future__ import print_function import torch import torch.nn as nn from torch.autograd import Variable From 048c66e2644436beb83e1448990d89996ee9ce08 Mon Sep 17 00:00:00 2001 From: Zihao Fu Date: Wed, 22 Mar 2017 14:31:53 +0800 Subject: [PATCH 4/8] Update README.md --- time_sequence_prediction/README.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/time_sequence_prediction/README.md b/time_sequence_prediction/README.md index e6af43f461..6353f003d1 100644 --- a/time_sequence_prediction/README.md +++ b/time_sequence_prediction/README.md @@ -1 +1,13 @@ -Time Sequence Prediction +# Time Sequence Prediction +This is a toy example for beginners to start with both pytorch and time sequence prediction. Two LSTMCell is used in this example to learn some sine wave signals starting at different phase. After learning the sine waves, the network try to predict the signals in the future. The results is shown in the picture below. + +## Usage + +``` +python generate_sine_wave.py +python train.py +``` + +## Result +The initial signal and the prediction results are shown in the image. We firstly give some initial signals (full line). The network will subsequently give some predicted results (dash line). It can be concluded that the network can generate new sine waves. +![image](https://cloud.githubusercontent.com/assets/1419566/24184438/e24f5280-0f08-11e7-8f8b-4d972b527a81.png) From f3cddf829a869f9c875a843bcf1c4079376c2c1b Mon Sep 17 00:00:00 2001 From: Zihao Fu Date: Wed, 22 Mar 2017 14:33:06 +0800 Subject: [PATCH 5/8] Update README.md --- time_sequence_prediction/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/time_sequence_prediction/README.md b/time_sequence_prediction/README.md index 6353f003d1..72dbf61762 100644 --- a/time_sequence_prediction/README.md +++ b/time_sequence_prediction/README.md @@ -1,5 +1,5 @@ # Time Sequence Prediction -This is a toy example for beginners to start with both pytorch and time sequence prediction. Two LSTMCell is used in this example to learn some sine wave signals starting at different phase. After learning the sine waves, the network try to predict the signals in the future. The results is shown in the picture below. +This is a toy example for beginners to start with. It is helpful for learning both pytorch and time sequence prediction. Two LSTMCell is used in this example to learn some sine wave signals starting at different phase. After learning the sine waves, the network try to predict the signals in the future. The results is shown in the picture below. ## Usage From 4804aa7efaa611f0a8066d9d1e3b7c63211061dc Mon Sep 17 00:00:00 2001 From: fuzihao Date: Wed, 22 Mar 2017 14:46:38 +0800 Subject: [PATCH 6/8] modify step number --- time_sequence_prediction/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/time_sequence_prediction/train.py b/time_sequence_prediction/train.py index a5c3718403..ee88c97d12 100644 --- a/time_sequence_prediction/train.py +++ b/time_sequence_prediction/train.py @@ -50,7 +50,7 @@ def forward(self, input, future = 0): # use LBFGS as optimizer since we can load the whole data to train optimizer = optim.LBFGS(seq.parameters()) #begin to train - for i in range(10): + for i in range(15): print('STEP: ', i) def closure(): optimizer.zero_grad() From e4d2040cc7293fd7a0f777e8abd9a0a82464d6ad Mon Sep 17 00:00:00 2001 From: Soumith Chintala Date: Tue, 4 Apr 2017 23:00:06 -0400 Subject: [PATCH 7/8] Update README.md --- time_sequence_prediction/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/time_sequence_prediction/README.md b/time_sequence_prediction/README.md index 72dbf61762..d2fccf0cb5 100644 --- a/time_sequence_prediction/README.md +++ b/time_sequence_prediction/README.md @@ -1,5 +1,5 @@ # Time Sequence Prediction -This is a toy example for beginners to start with. It is helpful for learning both pytorch and time sequence prediction. Two LSTMCell is used in this example to learn some sine wave signals starting at different phase. After learning the sine waves, the network try to predict the signals in the future. The results is shown in the picture below. +This is a toy example for beginners to start with. It is helpful for learning both pytorch and time sequence prediction. Two LSTMCell units are used in this example to learn some sine wave signals starting at different phases. After learning the sine waves, the network tries to predict the signal values in the future. The results is shown in the picture below. ## Usage @@ -9,5 +9,5 @@ python train.py ``` ## Result -The initial signal and the prediction results are shown in the image. We firstly give some initial signals (full line). The network will subsequently give some predicted results (dash line). It can be concluded that the network can generate new sine waves. +The initial signal and the predicted results are shown in the image. We first give some initial signals (full line). The network will subsequently give some predicted results (dash line). It can be concluded that the network can generate new sine waves. ![image](https://cloud.githubusercontent.com/assets/1419566/24184438/e24f5280-0f08-11e7-8f8b-4d972b527a81.png) From fbc45a79a5ca5cf6d1689bf92d16691bb73dfeeb Mon Sep 17 00:00:00 2001 From: fuzihao Date: Wed, 5 Apr 2017 14:09:55 +0800 Subject: [PATCH 8/8] remove cPickle --- time_sequence_prediction/generate_sine_wave.py | 6 +++--- time_sequence_prediction/train.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/time_sequence_prediction/generate_sine_wave.py b/time_sequence_prediction/generate_sine_wave.py index 8a68a57b5d..66dd941701 100644 --- a/time_sequence_prediction/generate_sine_wave.py +++ b/time_sequence_prediction/generate_sine_wave.py @@ -1,12 +1,12 @@ import math import numpy as np -import cPickle as pickle +import torch T = 20 L = 1000 N = 100 np.random.seed(2) x = np.empty((N, L), 'int64') x[:] = np.array(range(L)) + np.random.randint(-4*T, 4*T, N).reshape(N, 1) -y = np.sin(x / 1.0 / T).astype('float64') -pickle.dump(y, open('traindata.pkl', 'wb')) +data = np.sin(x / 1.0 / T).astype('float64') +torch.save(data, open('traindata.pt', 'wb')) diff --git a/time_sequence_prediction/train.py b/time_sequence_prediction/train.py index ee88c97d12..1aa99222a8 100644 --- a/time_sequence_prediction/train.py +++ b/time_sequence_prediction/train.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn from torch.autograd import Variable -import cPickle as pickle import torch.optim as optim import numpy as np import matplotlib @@ -40,7 +39,7 @@ def forward(self, input, future = 0): np.random.seed(0) torch.manual_seed(0) # load data and make training set - data = pickle.load(open('traindata.pkl')) + data = torch.load(open('traindata.pt')) input = Variable(torch.from_numpy(data[3:, :-1]), requires_grad=False) target = Variable(torch.from_numpy(data[3:, 1:]), requires_grad=False) # build the model