In [4]:
require 'nn'

-- Make single-precision float tensors the default type for tensors created
-- from here on.
torch.setdefaulttensortype('torch.FloatTensor')

-- Declare the custom layer class 'nn.MyLayer' with 'nn.Module' as its parent.
-- `MyLayer` is the class table the methods below are attached to; `parent` is
-- used by the constructor to call the base-class initialiser.
-- Both names are globals (no `local`).
MyLayer, parent = torch.class('nn.MyLayer', 'nn.Module')

Input: two numbers $(x, y)$; output: $\left(e^{ax+by},\ \sin(cx) \cdot \cos(dy)\right)$, where $a, b, c, d$ are the layer's scalar parameters.

In [30]:
-- MyLayer definition

--- Constructor: initialise the base module, then give the layer fresh
-- parameter values and zeroed gradient accumulators.
function MyLayer:__init()
    parent.__init(self)
    -- The two steps below touch disjoint fields (a..d vs. gradA..gradD),
    -- so their relative order does not matter.
    self:reset()
    self:zeroGradParameters()
end

--- (Re)initialise the four scalar parameters a, b, c, d, each drawn
-- with torch.uniform(-1, 1).
function MyLayer:reset()
    -- Drawn in the order a, b, c, d so the random stream is consumed in the
    -- same sequence every time.
    for _, name in ipairs({'a', 'b', 'c', 'd'}) do
        self[name] = torch.uniform(-1, 1)
    end
end

--- Forward pass.
-- Reads the first two elements of `input` as (x, y) and writes
-- (exp(a*x + b*y), sin(c*x) * cos(d*y)) into the 2-element `self.output`.
-- @param input  indexable with [1] and [2]
-- @return self.output
function MyLayer:updateOutput(input)
    local out = self.output
    out:resize(2)

    local x = input[1]
    local y = input[2]

    local exponent = self.a * x + self.b * y
    out[1] = torch.exp(exponent)
    out[2] = torch.sin(self.c * x) * torch.cos(self.d * y)

    return out
end

--- Backward pass with respect to the input.
-- Uses `self.output[1]` (the exponential computed by the forward pass) and
-- the chain rule to fill the 2-element `self.gradInput`:
--   d/dx = g1 * a * exp(..) + g2 * c * cos(c*x) * cos(d*y)
--   d/dy = g1 * b * exp(..) - g2 * d * sin(c*x) * sin(d*y)
-- @param input       indexable with [1] and [2]
-- @param gradOutput  indexable with [1] and [2]
-- @return self.gradInput, or nothing when `self.gradInput` is unset
function MyLayer:updateGradInput(input, gradOutput)
    local grad = self.gradInput
    if not grad then
        return
    end
    grad:resize(2)

    local x, y = input[1], input[2]
    local gExp, gTrig = gradOutput[1], gradOutput[2]
    local expOut = self.output[1]

    grad[1] = gExp * self.a * expOut
        + gTrig * self.c * torch.cos(self.c * x) * torch.cos(self.d * y)

    grad[2] = gExp * self.b * expOut
        - gTrig * self.d * torch.sin(self.c * x) * torch.sin(self.d * y)

    return grad
end

--- Accumulate the gradients of the loss with respect to a, b, c, d into
-- gradA..gradD, each contribution multiplied by `scale`.
-- Relies on `self.output[1]` holding the exponential from the forward pass.
-- @param input       indexable with [1] and [2]
-- @param gradOutput  indexable with [1] and [2]
-- @param scale       optional multiplier, defaults to 1
function MyLayer:accGradParameters(input, gradOutput, scale)
    local factor = scale or 1

    local x, y = input[1], input[2]
    local wExp = factor * gradOutput[1]   -- weight of the exp(..) output
    local wTrig = factor * gradOutput[2]  -- weight of the sin*cos output
    local expOut = self.output[1]

    -- d exp(a*x + b*y) / da = x * exp(..), and likewise with y for b.
    self.gradA = self.gradA + wExp * x * expOut
    self.gradB = self.gradB + wExp * y * expOut

    -- d (sin(c*x) cos(d*y)) / dc =  x * cos(c*x) * cos(d*y)
    self.gradC = self.gradC +
        wTrig * torch.cos(self.d * y) * torch.cos(self.c * x) * x
    -- d (sin(c*x) cos(d*y)) / dd = -y * sin(c*x) * sin(d*y)
    self.gradD = self.gradD -
        wTrig * torch.sin(self.c * x) * torch.sin(self.d * y) * y
end

--- Reset the four scalar gradient accumulators to zero.
function MyLayer:zeroGradParameters()
    self.gradA = 0
    self.gradB = 0
    self.gradC = 0
    self.gradD = 0
end

--- Plain gradient-descent step: param = param - lr * grad for a, b, c, d.
-- @param lr  learning rate
function MyLayer:updateParameters(lr)
    local pairsToUpdate = {
        {'a', 'gradA'},
        {'b', 'gradB'},
        {'c', 'gradC'},
        {'d', 'gradD'},
    }
    for _, entry in ipairs(pairsToUpdate) do
        local param, grad = entry[1], entry[2]
        self[param] = self[param] - lr * self[grad]
    end
end

Test this layer in a 2-layer "perceptron" to classify points in 2D:

In [57]:
-- Build the 2-D toy dataset on a 0.1-spaced grid over [-1, 1] x [-1, 1].
--
-- The loops count in integer tenths and divide by 10, instead of stepping a
-- floating-point counter by 0.1. With a fractional step the accumulated
-- rounding error decides whether boundary points (y = -x, y = 0.5 - x, ...)
-- are produced, which makes the number and order of samples unreliable.
-- Integer counters make it exact: 121 points of class 1 (the last one, index
-- 121, is (0, 0)), followed by 236 points of class 2, 357 in total.
dataset = {}
labels = {}

-- Append one labelled sample.
local function addPoint(x, y, label)
    dataset[#dataset + 1] = {x, y}
    labels[#labels + 1] = label
end

-- Class 1: the triangle |y| <= -x for x in [-1, 0].
for i = -10, 0 do
    for j = i, -i do
        addPoint(i / 10, j / 10, 1)
    end
end

-- Class 2, part 1: for x in [-0.5, 0.5], the points below y = x - 0.5 and
-- the points above y = 0.5 - x (both boundaries included).
for i = -5, 5 do
    for j = -10, i - 5 do
        addPoint(i / 10, j / 10, 2)
    end

    for j = 10, 5 - i, -1 do
        addPoint(i / 10, j / 10, 2)
    end
end
-- At x = 0.5 both inner loops end on y = 0, so (0.5, 0) was added twice and
-- the second copy is the last element: remove it.
dataset[#dataset] = nil
labels[#labels] = nil

-- Class 2, part 2: the full columns x in [0.6, 1].
for i = 6, 10 do
    for j = -10, 10 do
        addPoint(i / 10, j / 10, 2)
    end
end

dataset = torch.Tensor(dataset)
labels = torch.Tensor(labels)

-- Standardise each coordinate: subtract the column mean, then divide by the
-- column standard deviation.
dataset[{{}, 1}]:add(-dataset[{{}, 1}]:mean())
dataset[{{}, 2}]:add(-dataset[{{}, 2}]:mean())
dataset[{{}, 1}]:div(dataset[{{}, 1}]:std())
dataset[{{}, 2}]:div(dataset[{{}, 2}]:std())

In [32]:
cv = require 'cv'
require 'cv.highgui'
require 'cv.imgproc'

--- Render the decision regions of `net` and its predictions on the dataset
-- into a 600x600 image and show it in the OpenCV window 'w'.
--
-- The square [-1, 1] x [-1, 1] is tiled with cells of side `step`; each cell
-- is filled with {0,255,0} when the network predicts class 1 and with
-- {255,100,100} otherwise. Every row of the global `dataset` is then drawn as
-- a filled circle, {0,0,0} for predicted class 1 and {255,255,255} otherwise.
-- @param net  network whose forward() output supports :max(1)
function drawPredictions(net)
    local size = 600
    -- NOTE(review): `img` is assigned as a global here, as before; confirm
    -- that no other cell reads it before turning it into a local.
    img = torch.ByteTensor(size, size, 3):zero()

    local step = 0.025

    -- Index of the largest network output for a single point.
    local function predictedClass(point)
        local _, argmax = net:forward(point):max(1)
        return argmax[1]
    end

    -- Map a coordinate in [-1, 1] to a pixel position in [0, size].
    local function toPixel(v)
        return size * (v + 1) / 2
    end

    for x = -1,1,step do
        for y = -1,1,step do
            local class = predictedClass(torch.Tensor{x,y})
            -- `local`: this used to leak into the global environment.
            local color = class == 1 and {0,255,0} or {255,100,100}
            cv.rectangle{img,
                {toPixel(x), toPixel(y)},
                {toPixel(x+step), toPixel(y+step)}, color, cv.FILLED}
        end
    end

    for i = 1,dataset:size(1) do
        local class = predictedClass(dataset[i])
        cv.circle{img, ((dataset[i]+1)/2*size):totable(), 3,
            class == 1 and {0,0,0} or {255,255,255}, cv.FILLED}
    end

    cv.imshow{'w', img}
    cv.waitKey{1}
end

In [70]:
-- Model: the custom layer followed by a 2 -> 2 linear layer, trained with a
-- cross-entropy criterion.
net = nn.Sequential()
net:add(nn.MyLayer())
-- net:add(nn.Tanh())
net:add(nn.Linear(2, 2))

-- local L2 = 0.001

crit = nn.CrossEntropyCriterion()

cv.destroyAllWindows{}
cv.namedWindow{'w'}

local manual = true -- manual gradient accumulation, NOT batch mode
local lr = 25e-1

local numIterations = 25
local decayEvery, decayFactor = 50, 0.85
local numSamples = dataset:size(1)
-- Sample 121 contributes with 50 times the weight of any other sample.
local heavyIdx, heavyWeight = 121, 50

for iter = 1, numIterations do
    -- With 25 iterations this decay (every 50th iteration) never triggers.
    if iter % decayEvery == 0 then
        lr = lr * decayFactor
    end

    drawPredictions(net)

    -- Both modes start every iteration from zeroed gradient accumulators.
    net:zeroGradParameters()

    if manual then
        -- One full pass over the data, accumulating the per-sample gradients
        -- scaled by weight / numSamples, followed by a single update.
        local totalErr = 0

        for idx = 1, numSamples do
--             local idx = torch.random(dataset:size(1))
            local sample, target = dataset[idx], labels[idx]

            local pred = net:forward(sample)
            totalErr = totalErr + crit:forward(pred, target)

            local dLoss_dOutput = crit:backward(pred, target)
            local weight = 1
            if idx == heavyIdx then
                weight = heavyWeight
            end

            net:updateGradInput(sample, dLoss_dOutput)
            net:accGradParameters(sample, dLoss_dOutput, weight / numSamples)
        end

        net:updateParameters(lr)
        print('Error:', totalErr / numSamples)
    else
        -- NOTE(review): MyLayer reads input[1] and input[2] as the two
        -- coordinates of one point; confirm it handles a whole
        -- (numSamples x 2) batch before enabling this branch.
        local pred = net:forward(dataset)
        local err  = crit:forward(pred, labels)
        print('Error:' .. err)

        local dLoss_dOutput = crit:backward(pred, labels)
        net:backward(dataset, dLoss_dOutput)

    --     net:get(1).gradBias:add(net:get(1).bias*L2)
    --     net:get(1).gradWeight:add(net:get(1).weight*L2)
    --     net:get(3).gradBias:add(net:get(3).bias*L2)
    --     net:get(3).gradWeight:add(net:get(3).weight*L2)

        net:updateParameters(lr)
    end
end

drawPredictions(net)

cv.destroyAllWindows{}

Error:	0.045873041586773	


Error:	0.045528263281865	


Error:	0.045191595748208	


Error:	0.044862643679127	


Error:	0.044541130312067	


Error:	0.044226698686044	


Error:	0.043919077895111	


Error:	0.043617969714089	


Error:	0.043323123857241	


Error:	0.043034292517631	


Error:	0.042751228248501	


Error:	0.042473747327406	


Error:	0.042201617888433	


Error:	0.041934642302361	


Error:	0.041672647543245	


Error:	0.041415461894511	


Error:	0.041162895367854	


Error:	0.040914824529914	


Error:	0.040671071998233	


Error:	0.040431502218557	


Error:	0.040195968359698	


Error:	0.039964371432569	


Error:	0.039736575313653	


Error:	0.039512455420464	


Error:	0.039291908558694	
