In [1]:
import chainer

In [2]:
from chainer.dataset import convert

In [3]:
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import reporter

In [4]:
import rot
from rot import rotation3d

In [5]:
import numpy as np

class PreprocessedDataset(chainer.dataset.DatasetMixin):
    """Synthetic regression data for recovering a rigid 3-D transform.

    Inputs ``x`` are ``nb_data`` random points drawn from ``[0, 1)^3``.
    Targets are ``y = x @ R(r).T + t``, where the rotation vector ``r`` and
    the translation ``t`` are drawn once per dataset from ``[0, 1)^3`` and
    ``R(r)`` is the Rodrigues rotation matrix. The ground-truth ``r`` and
    ``t`` are kept on the instance so a trained model can be checked
    against them.

    Args:
        nb_data: Number of examples to generate.
        seed: Optional seed for a private ``np.random.RandomState``. When
            ``None`` (the default), NumPy's global RNG is used exactly as
            before, so ``np.random.seed`` still controls the draw.
    """

    def __init__(self, nb_data, seed=None):
        rng = np.random if seed is None else np.random.RandomState(seed)
        # Draw order (x, then r, then t) matches the original implementation
        # so a given global seed reproduces the same dataset.
        self.x = rng.rand(nb_data, 3).astype(np.float32)
        self.r = rng.rand(3).astype(np.float32)      # ground-truth rotation vector
        self.t = rng.rand(1, 3).astype(np.float32)   # ground-truth translation
        # Vectorised over all points: (N, 3) @ R^T + (1, 3) -> (N, 3) float32.
        # This matches the previous per-row loop up to float32 rounding. It
        # also keeps shape (0, 3) / float32 when nb_data == 0, where
        # np.array([]) used to yield a float64 array of shape (0,).
        self.y = self.rotation(self.x, self.r) + self.t

    def rodrigues(self, r):
        """Convert a rotation vector to a 3x3 rotation matrix.

        Applies Rodrigues' formula ``R = I + sin(θ) S(n) + (1 - cos(θ)) S(n)^2``
        with ``θ = |r|`` and ``n = r / θ``. When ``θ`` is tiny, the
        second-order Taylor expansion in ``S(r)`` is used instead, which
        avoids dividing by ~0.

        Args:
            r: Length-3 rotation vector (NumPy array).

        Returns:
            3x3 rotation matrix cast to ``r.dtype``.
        """
        def S(n):
            # Skew-symmetric cross-product matrix: S(n) @ v == np.cross(n, v).
            return np.array([[0, -n[2], n[1]],
                             [n[2], 0, -n[0]],
                             [-n[1], n[0], 0]])

        theta = np.linalg.norm(r)

        if theta > 1e-16:
            Sn = S(r / theta)
            R = (np.eye(3)
                 + np.sin(theta) * Sn
                 + (1 - np.cos(theta)) * np.dot(Sn, Sn))
        else:
            # sin(θ)/θ ≈ 1 - θ²/6 and (1 - cos(θ))/θ² ≈ 1/2 - θ²/24.
            Sr = S(r)
            theta2 = theta ** 2
            R = (np.eye(3)
                 + (1 - theta2 / 6.) * Sr
                 + (.5 - theta2 / 24.) * np.dot(Sr, Sr))

        return R.astype(r.dtype)

    def rotation(self, x, r):
        """Rotate point(s) ``x`` by rotation vector ``r``.

        Accepts a single point of shape ``(3,)`` or a batch of shape
        ``(N, 3)``. Each point ``p`` is mapped to ``R(r) @ p``, and the
        result keeps ``x``'s dtype.
        """
        rmat = self.rodrigues(r)
        return x.dot(rmat.T).astype(x.dtype, copy=False)

    def __len__(self):
        """Number of examples."""
        return len(self.x)

    def get_example(self, i):
        """Return the ``(x, y)`` pair at index ``i``."""
        return self.x[i], self.y[i]

In [6]:
class Net(chainer.Chain):
    """Learnable rigid transform ``x -> rotation3d(x, r) + t``.

    Both parameters live in a single ``EmbedID(2, 3)`` table. Row 0 holds
    the rotation vector ``r`` and row 1 holds the translation ``t``.
    """

    def __init__(self):
        super(Net, self).__init__()

        with self.init_scope():
            self.embd = L.EmbedID(2, 3)

    def __call__(self, x):
        """Apply the learned transform to ``x``.

        Args:
            x: Points, either a ``chainer.Variable`` or a raw array.
                ``t`` is broadcast to ``x.shape``. Presumably ``x`` is
                ``(N, 3)``, as ``rotation3d`` (from the external ``rot``
                module) expects. TODO: confirm.

        Returns:
            ``rotation3d(x, r) + t``.
        """
        # Pick numpy vs. cupy from the underlying array. Taking ``x.data``
        # unconditionally is wrong for a raw ndarray: there ``.data`` is a
        # buffer (numpy) or a MemoryPointer (cupy), not an array.
        x_array = x.data if isinstance(x, chainer.Variable) else x
        xp = chainer.cuda.get_array_module(x_array)

        r = self.embd(xp.array([0], dtype=np.int32))  # row 0 -> shape (1, 3)
        r = F.reshape(r, (3,))

        t = self.embd(xp.array([1], dtype=np.int32))  # row 1 -> shape (1, 3)
        t = F.broadcast_to(t, x.shape)

        return rotation3d(x, r) + t

In [7]:
# Seed NumPy's global RNG so that the synthetic dataset drawn here, the
# parameter initialisation in Net() and the iterator's shuffling are all
# reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)
data = PreprocessedDataset(nb_data=1000)

In [8]:
# Model to train: a learnable rotation vector and translation applied to 3-D points.
net = Net()

In [9]:
class loss_function(chainer.link.Chain):
    """Wrap a predictor and score its output with mean-squared error.

    The most recent loss is stored in ``self.loss`` (the training loop reads
    it) and reported under the key ``'loss'`` via ``chainer.reporter``.

    Args:
        predictor: Link mapping inputs ``x`` to predictions comparable to ``y``.
    """

    def __init__(self, predictor):
        super(loss_function, self).__init__()
        # Register the child link inside init_scope, as Net does, rather
        # than via the Chain(**links) constructor. That constructor form
        # goes through add_link, which is deprecated since Chainer v2.
        with self.init_scope():
            self.predictor = predictor

    def __call__(self, x, y):
        """Return the mean-squared error between ``self.predictor(x)`` and ``y``."""
        py = self.predictor(x)
        self.loss = F.mean_squared_error(py, y)
        reporter.report({'loss': self.loss}, self)
        return self.loss

In [10]:
# Attach the MSE training objective to the network.
model = loss_function(predictor=net)

In [11]:
# Plain SGD over all parameters reachable from `model`. lr=1 is unusually
# large; the recorded full-batch run below nevertheless converged with it.
optimizer = chainer.optimizers.SGD(lr=1)
optimizer.setup(model)

In [12]:
# Full-batch training: the batch size equals the dataset size, so one
# optimizer update corresponds to one epoch.
data_count = len(data)
N_EPOCHS = 500
LOG_INTERVAL = 50

data_iter = chainer.iterators.SerialIterator(data, data_count)

sum_loss = 0.0
n_seen = 0

while data_iter.epoch < N_EPOCHS:
    batch = data_iter.next()
    x_array, y_array = convert.concat_examples(batch, -1)
    x = chainer.Variable(x_array)
    y = chainer.Variable(y_array)
    optimizer.update(model, x, y)
    # model.loss is the batch-mean MSE, so weight it by the batch size and
    # divide by the number of examples actually seen. This stays correct if
    # the batch size is ever changed to something other than data_count.
    sum_loss += float(model.loss.data) * len(y_array)
    n_seen += len(y_array)

    if data_iter.is_new_epoch:
        if data_iter.epoch % LOG_INTERVAL == 0:
            print('epoch: {}, train mean loss: {}'.format(data_iter.epoch, sum_loss / n_seen))
        sum_loss = 0.0
        n_seen = 0

epoch: 50, train mean loss: 0.00011391683074180037
epoch: 100, train mean loss: 2.518971768949996e-07
epoch: 150, train mean loss: 4.954581744875952e-10
epoch: 200, train mean loss: 8.143964018783589e-13
epoch: 250, train mean loss: 6.0233368382530024e-15
epoch: 300, train mean loss: 6.0233368382530024e-15
epoch: 350, train mean loss: 6.0233368382530024e-15
epoch: 400, train mean loss: 6.0233368382530024e-15
epoch: 450, train mean loss: 6.0233368382530024e-15
epoch: 500, train mean loss: 6.0233368382530024e-15
