In [1]:
# A future statement must be the first statement of the cell/module.
from __future__ import print_function

# NOTE(review): star import kept for backward compatibility with any later
# cells that rely on it; prefer the explicit names imported below.
from theano import *

import numpy
import theano
import theano.tensor as T
from theano import function
from theano.tensor.shared_randomstreams import RandomStreams

In [2]:
# Symbolic float64 matrix named 'a' (same type T.dmatrix would build).
# NOTE(review): not referenced by any of the later cells shown here.
a = T.matrix('a', dtype='float64')

In [3]:
# Symbolic float32 vector named 'b' (same type T.fvector would build).
# NOTE(review): the following cell rebinds `b` to a shared bias scalar,
# so this symbolic variable is shadowed there.
b = T.vector('b', dtype='float32')

In [4]:
# Fixed seed so that "Restart Kernel -> Run All" reproduces the same dataset,
# initial weights and therefore the same printed results.
RNG_SEED = 1234
numpy.random.seed(RNG_SEED)
rng = numpy.random

N = 400                                   # training sample size
feats = 784                               # number of input variables
learning_rate = 0.1                       # gradient-descent step size
l2_penalty = 0.01                         # weight of the L2 penalty on w
training_steps = 10000                    # number of full-batch updates

# generate a dataset: D = (input_values, target_class)
# inputs are standard-normal draws, targets are independent 0/1 labels
D = (rng.randn(N, feats), rng.randint(size=N, low=0, high=2))

# Declare Theano symbolic variables
x = T.dmatrix("x")
y = T.dvector("y")

# initialize the weight vector w randomly
#
# this and the following bias variable b
# are shared so they keep their values
# between training iterations (updates)
w = theano.shared(rng.randn(feats), name="w")

# initialize the bias term
# (rebinds `b`; any earlier symbolic `b` is shadowed from here on)
b = theano.shared(0., name="b")

print("Initial model:")
print(w.get_value())
print(b.get_value())

# Construct Theano expression graph
# sigmoid(x.w + b) is the same function as 1 / (1 + exp(-x.w - b)); using the
# dedicated op states the intent directly instead of relying on the graph
# optimizer to recognise the hand-written formula.
p_1 = T.nnet.sigmoid(T.dot(x, w) + b)              # Probability that target = 1
prediction = p_1 > 0.5                             # The prediction thresholded
xent = -y * T.log(p_1) - (1 - y) * T.log(1 - p_1)  # Cross-entropy loss function
cost = xent.mean() + l2_penalty * (w ** 2).sum()   # The cost to minimize
gw, gb = T.grad(cost, [w, b])                      # Gradient of the cost w.r.t.
                                                   # weight vector w and bias b

# Compile
# `train` performs one full-batch gradient step and returns the predictions
# and per-sample cross-entropy computed *before* that step's update.
train = theano.function(
    inputs=[x, y],
    outputs=[prediction, xent],
    updates=((w, w - learning_rate * gw), (b, b - learning_rate * gb)))
predict = theano.function(inputs=[x], outputs=prediction)

# Train
for i in range(training_steps):
    pred, err = train(D[0], D[1])

print("Final model:")
print(w.get_value())
print(b.get_value())
print("target values for D:")
print(D[1])
print("prediction on D:")
final_prediction = predict(D[0])
print(final_prediction)

# Summarise the fit instead of leaving the reader to compare 400 labels by eye.
print("training accuracy: %.3f" % numpy.mean(final_prediction == D[1]))
print("mean cross-entropy at last step: %.6f" % numpy.mean(err))

Initial model:
[  2.02693794e+00   1.40908767e-01  -5.72013194e-01  -2.39651775e+00
  -1.11285239e+00  -1.22400978e-01  -3.07869971e-01  -1.49090589e+00
  -1.94307911e+00  -9.57316412e-01   4.43096417e-01  -1.77075303e+00
   1.46556219e-02  -1.86523422e-01   1.05142481e-01  -3.26526540e-01
  -4.73572015e-01   5.90649412e-01  -1.72177420e-01  -2.12563463e-02
  -3.20941884e-01  -2.02832167e+00  -8.76962258e-01   3.73974074e-01
  -5.21882806e-01   1.05892013e+00  -1.74328379e+00  -8.05824744e-01
   1.67155923e+00   1.72566642e-03  -1.25388722e+00   3.95523194e-01
  -1.32593797e+00   1.54976212e+00  -8.92439343e-02   1.13369655e+00
   2.20686302e-01   5.91549048e-01   2.75624216e-01   4.29328760e-01
  -2.46649594e-01   4.49185194e-01   7.43658801e-01  -1.76765878e+00
   7.49781429e-01   1.48426413e+00  -1.78310175e-01   1.71398879e-01
   1.13525418e+00   3.61820838e-01   4.27310254e-01   1.79571707e-01
   4.87356234e-01  -4.58310966e-01  -3.02526472e-02  -2.16204676e-02
   9.35873212e-01  

Final model:
[  4.50753814e-02  -3.95029099e-02   8.66047342e-02   4.72092943e-02
  -9.61913160e-02   4.82871807e-02  -2.21177863e-01  -2.07236864e-02
   2.34316990e-02  -2.48065848e-02  -5.65083645e-02   1.12632507e-02
   3.66897465e-03  -5.32265847e-02  -6.42577982e-02  -1.38630129e-01
   8.06960534e-02  -3.73001305e-02  -3.59374143e-02  -6.36238540e-03
  -5.07612506e-02  -7.35421454e-02  -2.31823592e-01   3.18913407e-03
  -2.71317318e-02  -8.83623578e-02  -3.80930858e-02  -2.12090784e-02
   7.73462375e-02  -1.51759425e-01   5.79183155e-02  -3.79395655e-02
   3.36229614e-02  -2.30356362e-03   8.11194396e-02   8.66034491e-02
  -4.79897129e-03   1.08602729e-01   1.78861801e-01  -4.06546953e-02
  -1.03295337e-01  -9.50062010e-02   1.22290777e-01   3.76777797e-02
   3.29208100e-02  -5.08082999e-02   5.60012521e-02  -5.79594156e-02
   5.69427852e-02   9.09376970e-02  -6.59869691e-02  -1.13738614e-01
  -2.57469639e-02   5.01265520e-02  -1.09824226e-01   6.40340839e-02
   1.18195248e-01  -5