-
Notifications
You must be signed in to change notification settings - Fork 1
/
figure8.py
130 lines (110 loc) · 4.98 KB
/
figure8.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# -----------------------------------------------------------------------------
# Gated working memory with an echo state network
# Copyright (c) 2018 Nicolas P. Rougier
#
# Distributed under the terms of the BSD License.
# -----------------------------------------------------------------------------
# This script tests stability of the reservoir. The protocol is as follows:
#
# 1. Train the model using teacher forcing (-> Wout)
# 2. For output o in [-5,+5]
# Choose a random reservoir state
# Remove input and force output o at t=0
# Iterate over 500 timesteps
#
# Expected behavior (after 500 timesteps):
# For output(t=0) in [-1,+1], the output does not change
# For output(t=0) > +1, the output converges towards +1
# For output(t=0) < -1, the output converges towards -1
# -----------------------------------------------------------------------------
import tqdm
import numpy as np
import matplotlib.pyplot as plt
from data import generate_data, smoothen
from model import generate_model, train_model, test_model
from mpl_toolkits.axes_grid1 import make_axes_locatable
def free_run(model, output, internals, n_epoch, t_record, n_gate=1):
    """Load ``output`` into the reservoir, then let it evolve with no input.

    One loading step feeds ``output`` through ``W_in`` with all ``n_gate``
    gates open, together with the recurrent and feedback terms. The
    reservoir then runs for ``n_epoch`` steps, driven only by ``W_rc`` and
    by its own output fed back through ``W_fb``.

    Parameters
    ----------
    model : dict
        Reservoir with keys ``"W_rc"``, ``"W_fb"``, ``"W_in"``, ``"W_out"``
        and ``"leak"`` (as produced by ``generate_model`` in this script).
    output : ndarray
        Output forced at t=0. It is also the value loaded through the input.
        This script passes a length-1 array.
    internals : ndarray
        Initial reservoir state.
    n_epoch : int
        Number of free-running timesteps.
    t_record : int
        Step whose post-update state is returned as ``internals_mid``.
        Must satisfy ``0 <= t_record < n_epoch``.
    n_gate : int
        Number of gate inputs held open (set to 1) during the loading step.

    Returns
    -------
    outputs : ndarray, shape (n_epoch,)
        ``outputs[j]`` is the first output channel at the start of step j.
    internals_mid : ndarray
        Reservoir state right after the update of step ``t_record``.
    internals : ndarray
        Reservoir state after the last step.

    Raises
    ------
    ValueError
        If ``t_record`` is outside ``[0, n_epoch)``.
    """
    if not 0 <= t_record < n_epoch:
        raise ValueError("t_record must lie in [0, n_epoch)")
    leak = model["leak"]
    W_rc, W_fb, W_out = model["W_rc"], model["W_fb"], model["W_out"]

    # Loading step. The input vector is [value, gate_1..gate_n] with every
    # gate open. NOTE(review): assumes this matches the input layout used by
    # data.generate_data during training - confirm.
    # A flat float vector is required here: the former `[output, 1]` mixed
    # an array with a scalar and does not convert to a numeric array.
    inputs = np.concatenate((output, np.ones(n_gate)))
    internals_ = np.tanh(np.dot(W_rc, internals) + np.dot(W_fb, output)
                         + np.dot(model["W_in"], inputs))
    internals = (1 - leak) * internals + leak * internals_

    outputs = np.zeros(n_epoch)
    internals_mid = None
    for j in range(n_epoch):
        # Store a scalar. Assigning a 1-element array into a scalar slot
        # is deprecated in recent NumPy.
        outputs[j] = output[0]
        internals_ = np.tanh(np.dot(W_rc, internals) + np.dot(W_fb, output))
        internals = (1 - leak) * internals + leak * internals_
        output = np.dot(W_out, internals)
        if j == t_record:
            internals_mid = internals
    return outputs, internals_mid, internals


def _style_side_panel(ax, label):
    """Apply the decoration shared by side panels C and D."""
    ax.tick_params(axis="y", labelleft=False)
    for level in (+1.0, -1.0):
        ax.axhline(level, color="0.75", linewidth=0.75, zorder=-10)
    ax.axvline(0.0, color="0.75", linewidth=0.75, zorder=-10)
    ax.text(0.125, 0.02, label, transform=ax.transAxes,
            ha="left", va="bottom", fontsize=24, weight="bold")
    # Reference segment: no change is expected for initial outputs in [-1, +1].
    ax.plot([0, 0], [-1, 1], lw=1.5, color="red", zorder=-10)


def plot_figure(outputs, internals_mid, internals_end, starts, t_mid):
    """Draw panels A-D of figure 8 in a new figure and return it.

    The panels are:

    - A: first 5 timesteps of every output trajectory.
    - B: full output trajectories, with a dashed line at ``t_mid``.
    - C: |output(end) - output(0)| against the forced initial output.
    - D: max-norm distance between the reservoir state at ``t_mid`` and at
      the end, against the forced initial output.
    """
    fig = plt.figure(figsize=(10, 5))
    ax_b = plt.subplot(1, 1, 1)
    for trajectory in outputs:
        ax_b.plot(trajectory, color="k", alpha=0.25, lw=0.5)
    ax_b.text(0.98, 0.02, "B", transform=ax_b.transAxes,
              ha="right", va="bottom", fontsize=24, weight="bold")
    ax_b.axvline(t_mid, 0, 1, color="0.75", linewidth=0.75, zorder=-10,
                 linestyle="--")
    divider = make_axes_locatable(ax_b)

    ax_c = divider.append_axes("right", 1.2, pad=0.1, sharey=ax_b)
    ax_c.plot(np.abs(outputs[:, -1] - outputs[:, 0]), starts, color="k")
    _style_side_panel(ax_c, "C")
    ax_c.set_xlim([-1, 5.])

    ax_d = divider.append_axes("right", 1.2, pad=0.1, sharey=ax_b)
    ax_d.plot(np.linalg.norm(internals_mid - internals_end, ord=np.inf, axis=1),
              starts, color="k")
    _style_side_panel(ax_d, "D")
    ax_d.set_xlim([-0.1, 2.1])

    ax_a = divider.append_axes("left", 1.2, pad=0.1, sharey=ax_b)
    for trajectory in outputs:
        ax_a.plot(trajectory[:5], color="k", alpha=0.25, lw=0.5)
    ax_a.text(0.98, 0.02, "A", transform=ax_a.transAxes,
              ha="right", va="bottom", fontsize=24, weight="bold")
    return fig


def main():
    """Train the gated memory, free-run it from forced outputs, plot figure 8."""
    # Random generator initialization
    np.random.seed(123)

    # Build memory
    n_gate = 1
    n_unit = 1000
    model = generate_model(shape=(1 + n_gate, n_unit, n_gate),
                           sparsity=0.5,
                           radius=0.1,
                           scaling=1.0,
                           leak=1.0,
                           noise=(0.000, 0.0001, 0.000))

    # Training data
    n = 25000
    values = np.random.uniform(-1, +1, n)
    ticks = np.random.uniform(0, 1, (n, n_gate)) < 0.01
    train_data = generate_data(values, ticks)
    error = train_model(model, train_data)
    print("Training error : {0}".format(error))

    # Test the model with a random initial state and no input. Re-seeding
    # makes the initial states independent of the randomness used in training.
    np.random.seed(123)
    n_value = 100
    n_epoch = 500
    t_mid = n_epoch // 2
    vmin, vmax = -5, +5
    # Forced initial outputs, evenly spaced over [vmin, vmax].
    starts = vmin + (vmax - vmin) * (np.arange(n_value) / (n_value - 1))
    outputs = np.zeros((n_value, n_epoch))
    internals_init = np.zeros((n_value, n_unit))  # state at t = t_mid
    internals_end = np.zeros((n_value, n_unit))   # state after the last step
    for i in tqdm.trange(n_value):
        internals = np.random.uniform(-0.5, +0.5, n_unit)
        outputs[i], internals_init[i], internals_end[i] = free_run(
            model, np.array([starts[i]]), internals, n_epoch, t_mid, n_gate)

    # Display results
    plot_figure(outputs, internals_init, internals_end, starts, t_mid)
    plt.tight_layout()
    plt.savefig("figure8.pdf")
    plt.show()


if __name__ == '__main__':
    main()