In [1]:
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable
from chainer import optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L


In [2]:
# Set data: iris features/labels, split by row parity.
# Odd-positioned rows train the model; even-positioned rows are held out.
from sklearn import datasets
iris = datasets.load_iris()
X = iris.data.astype(np.float32)   # features as float32 for the network
Y = iris.target.astype(np.int32)   # integer class labels
N = Y.size
index = np.arange(N)

train_rows = index[index % 2 == 1]   # 1, 3, 5, ...
test_rows = index[index % 2 == 0]    # 0, 2, 4, ...

xtrain, ytrain = X[train_rows, :], Y[train_rows]
xtest, yans = X[test_rows, :], Y[test_rows]


In [3]:
# Define model

class IrisChain(Chain):
    """Two-layer perceptron used to classify the iris data.

    Architecture: Linear(n_in -> n_hidden) -> sigmoid -> Linear(n_hidden -> n_out).
    The defaults (4 -> 6 -> 3) reproduce the original hard-coded sizes:
    4 iris features in, 3 iris classes out.
    """

    def __init__(self, n_in=4, n_hidden=6, n_out=3):
        # Links are registered by passing them as keyword arguments to the
        # Chain constructor (Chainer v1 style). This notebook also relies on
        # the v1-only `volatile` flag, so `init_scope()` is deliberately not used.
        super(IrisChain, self).__init__(
            l1=L.Linear(n_in, n_hidden),
            l2=L.Linear(n_hidden, n_out),
        )

    def __call__(self, x, y):
        """Return the softmax cross-entropy loss of fwd(x) against labels y."""
        return F.softmax_cross_entropy(self.fwd(x), y)

    def fwd(self, x):
        """Return unnormalized class scores (logits) for the input batch x.

        No softmax is applied here: softmax_cross_entropy in __call__ expects
        raw scores, and argmax over them gives the predicted class.
        """
        h1 = F.sigmoid(self.l1(x))
        h2 = self.l2(h1)
        return h2

In [4]:
# Initialize model

# Seed NumPy's global RNG so the per-epoch shuffling in the training cell
# (np.random.permutation) is reproducible under Restart & Run All.
# NOTE(review): Chainer's default weight initializers presumably draw from
# the same global RNG, which would also make the initial weights
# reproducible -- verify for the installed Chainer version.
SEED = 0
np.random.seed(SEED)

model = IrisChain()
optimizer = optimizers.SGD()  # plain SGD with Chainer's default learning rate
optimizer.setup(model)

In [5]:
# Learn: mini-batch SGD over the training half.

n = len(xtrain)   # number of training samples (75 for the odd/even iris split)
bs = 25           # mini-batch size
n_epoch = 2000    # full passes over the training set

for j in range(n_epoch):
    # Reshuffle every epoch so the mini-batches differ between passes.
    sffindx = np.random.permutation(n)
    for i in range(0, n, bs):
        # NumPy slicing clamps at the end, so if bs does not divide n the
        # final batch is simply shorter -- no explicit bound check needed.
        batch = sffindx[i:i + bs]
        x = Variable(xtrain[batch])
        y = Variable(ytrain[batch])
        model.zerograds()
        loss = model(x, y)
        loss.backward()
        optimizer.update()

In [6]:
# Test: score the held-out half and report accuracy.

# `volatile='on'` is the Chainer v1 flag for skipping backward-graph
# construction; no gradients are needed at evaluation time.
xt = Variable(xtest, volatile='on')
yy = model.fwd(xt)

# Raw logits as a NumPy array: one row per test sample, one column per class.
ans = yy.data
nrow, ncol = ans.shape

# Print each row's scores alongside its predicted class, counting hits.
ok = 0
for i, row in enumerate(ans):
    cls = np.argmax(row)
    print(row, cls)
    ok += int(cls == yans[i])

print(ok, "/", nrow, " = ", ok / nrow)

[ 1.9342134  -0.60227877 -2.87409687] 0
[ 1.90794182 -0.58438176 -2.84659743] 0
[ 1.93580866 -0.60316813 -2.87522244] 0
[ 1.88876843 -0.58118999 -2.83257794] 0
[ 1.82086635 -0.54761857 -2.77366328] 0
[ 1.94463444 -0.6111989  -2.88639021] 0
[ 1.88670659 -0.576756   -2.83155251] 0
[ 1.98220468 -0.62454122 -2.91648459] 0
[ 1.96138859 -0.61633199 -2.89821148] 0
[ 1.93166327 -0.61127442 -2.88000226] 0
[ 1.89401102 -0.59285063 -2.84804034] 0
[ 1.96493983 -0.60218763 -2.89040828] 0
[ 1.77526855 -0.54725504 -2.75452852] 0
[ 1.86074877 -0.57734454 -2.81659174] 0
[ 1.93256497 -0.6013478  -2.87293625] 0
[ 1.83141565 -0.56145757 -2.79167819] 0
[ 1.96003103 -0.6186313  -2.8995769 ] 0
[ 1.88221955 -0.57905871 -2.83120537] 0
[ 1.95871687 -0.61188483 -2.8948102 ] 0
[ 1.86504078 -0.56306428 -2.8075695 ] 0
[ 1.93598104 -0.60070658 -2.87317419] 0
[ 1.88759518 -0.57409751 -2.82715034] 0
[ 1.84021306 -0.57680106 -2.80694795] 0
[ 1.92541099 -0.60470295 -2.87096286] 0
[ 1.94098473 -0.60933    -2.88293862] 0
