In [277]:
import numpy as np

import chainer
from chainer import functions as F
from chainer import links as L
from chainer import optimizers
from chainer import cuda


class Unsharing(object):
    name = 'Unsharing'
    def __init__(self, rate, source):
        self.rate = rate
        self.source = source

    def __call__(self, opt):
        for t_param, s_param in zip(opt.target.params(False), self.source.params(False)):
            t_data, t_grad = t_param.data, t_param.grad
            s_data = s_param.data
            xp = cuda.get_array_module(t_data)
            t_grad += self.rate * -1 * xp.absolute(t_data - s_data)


class MLP(chainer.Chain):
    def __init__(self):
        super(MLP, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(1, 1, nobias=True)
#            self.l2 = L.Linear(2, 1)

    def __call__(self, x):
        h = self.l1(x)
#        h = self.l2(h)
        return F.square(h)


In [325]:
x = np.asarray([[0], [1], [2], [3], [4]], dtype=np.float32)
t = np.asarray([[0], [4], [16], [36], [64]], dtype=np.float32)

In [326]:
net1 = MLP()
optimizer1 = optimizers.SGD(lr=0.001)
optimizer1.setup(net)

In [344]:
y = net1(x)
print(net1.l1.W.data)
loss = F.mean_squared_error(y, t)
net1.cleargrads()
loss.backward()
optimizer1.update()
print(loss)
print(net1.l1.W.data)
print(net.l1.W.grad)

[[-2.3570375]]
variable(171.33401)
[[-2.3570375]]
[[-684.7821]]


In [320]:
print(net.l1.W.data)

[[1.5770602]]


In [321]:
loss = F.mean_squared_error(y, t)
net.cleargrads()
loss.backward()
optimizer.update()

In [322]:
print(net.l1.W.data)

[[2.2527485]]


In [323]:
loss

variable(162.04768)

In [324]:
print(net.l1.W.grad)

[[-675.6881]]


In [109]:
net.l1.W.data = np.asarray([[-2]], dtype=np.float32)

In [452]:
net2 = MLP()
optimizer2 = optimizers.SGD(lr=0.001)
optimizer2.setup(net2)
optimizer2.add_hook(Unsharing(rate=100, source=net))

In [456]:
y = net2(x)
print(net2.l1.W.data)
loss = F.mean_squared_error(y, t)
net2.cleargrads()
loss.backward()
optimizer2.update()
print(net2.l1.W.data)
print(net2.l1.W.grad)

[[2.9904652]]
[[-0.5846801]]
[[3575.145]]


In [395]:
net2.l1.W

variable W([[-0.19406027]])