In [8]:
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import optimizers, training, iterators

In [5]:
class NN(chainer.Chain):
    def __init__(self, n_mid_units=100, n_out=10):
        super().__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_mid_units)
            self.l2 = L.Linear(None, n_mid_units)
            self.l3 = L.Linear(None, n_out)

    def forward(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

In [50]:
class NN2(chainer.Chain):
    def __init__(self, n_mid_units=100, n_out=10):
        super().__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_mid_units)
            self.l2 = L.Linear(None, n_mid_units)
            self.l3_ = L.Linear(None, n_out)

    def forward(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

In [9]:
from chainer.datasets import mnist

train, test = mnist.get_mnist()
batchsize = 128

train_iter = iterators.SerialIterator(train, batchsize)
test_iter = iterators.SerialIterator(test, batchsize, False, False)

Downloading from http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz...


In [10]:
model = NN()

max_epoch = 10

# Wrap your model by Classifier and include the process of loss calculation within your model.
# Since we do not specify a loss function here, the default 'softmax_cross_entropy' is used.
model = L.Classifier(model)

# selection of your optimizing method
optimizer = optimizers.MomentumSGD()

# Give the optimizer a reference to the model
optimizer.setup(model)

# Get an updater that uses the Iterator and Optimizer
updater = training.updaters.StandardUpdater(train_iter, optimizer)

In [12]:
from chainer.training import extensions

trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='mnist_result')
trainer.extend(extensions.LogReport())
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
trainer.extend(extensions.snapshot_object(model.predictor, filename='model_epoch-{.updater.epoch}'))
trainer.extend(extensions.Evaluator(test_iter, model))
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
trainer.extend(extensions.DumpGraph('main/loss'))

In [13]:
trainer.run()

epoch       main/loss   main/accuracy  validation/main/loss  validation/main/accuracy  elapsed_time
[J1           0.545101    0.848398       0.255823              0.926325                  2.11769       
[J2           0.229993    0.934102       0.180875              0.948477                  4.22368       
[J3           0.173434    0.950277       0.144885              0.959256                  6.41378       
[J4           0.139901    0.959719       0.127111              0.961828                  8.56746       
[J5           0.116996    0.966118       0.111384              0.966673                  10.8277       
[J6           0.101077    0.971149       0.106002              0.967959                  13.0318       
[J7           0.0883205   0.974597       0.0970856             0.971519                  15.1305       
[J8           0.0771667   0.977931       0.0900739             0.971519                  17.2369       
[J9           0.0680626   0.979894       0.0890723         

In [51]:
from chainer import serializers

model = NN()
target_model = NN2()
serializers.load_npz('mnist_result/model_epoch-10', model)

In [53]:
for child in model._children:
    if hasattr(target_model, child):
        attr = getattr(target_model, child)
        for name, p in model.__dict__[child].namedparams():
            if hasattr(attr, name[1:]):
                print(child, name[1:])
                print(p)
                attr.__dict__[name[1:]] = p
            else:
                print("Attribute '{}.{}' not found in the model. Failed to load weights".format(child, name[1:]))
    else:
        print("Attribute '{}' not found in the model. Failed to load weights".format(child))

for child in target_model.children():
    print(child.W)
    print(child.b)

Attribute 'l3' not found in the model. Failed to load weights
l1 W
variable W([[-0.04877891 -0.00524843 -0.00363599 ...  0.05494009
             -0.01535325  0.03098945]
            [ 0.00454965  0.02043998 -0.01013051 ...  0.02852278
              0.07175718  0.03172262]
            [ 0.01871255 -0.01793926  0.02691932 ... -0.0008819
             -0.00988861  0.06829702]
            ...
            [-0.04908756 -0.0145371  -0.01691245 ...  0.02448979
              0.0019673   0.03976892]
            [ 0.01391795  0.02018651  0.00824283 ...  0.00967187
              0.04573222  0.02666374]
            [ 0.05059663 -0.03989486  0.00735897 ...  0.02446741
             -0.02581615 -0.01250272]])
l1 b
variable b([ 0.04329899  0.02675249  0.01609204  0.06084536  0.03415363
            -0.01458526 -0.02003986  0.06404594  0.02512664  0.03988663
            -0.02014436 -0.01787255  0.00639743  0.04627038  0.01835883
             0.05692839  0.04853433 -0.00808021  0.03991929 -0.0148244
      