In [1]:
using Plots, Distances, Knet
# Use Plots.jl's Plotly backend for all plots produced below.
plotly()

Plots.PlotlyBackend()

In [96]:
using Knet

"""
    sigmoid(z)

Logistic function `1 / (1 + exp(-z))` of a real scalar `z`.

The result is computed in the floating-point type of `z`: a `Float32` input yields a
`Float32` result, and integer inputs yield a `Float64`. Broadcast (`sigmoid.(x)`) to
apply it to arrays.
"""
function sigmoid(z)
    # `one(z)` rather than the literal `1.0` so Float32 inputs are not promoted to Float64.
    return inv(one(z) + exp(-z))
end

"""
    relu(z)

Rectified linear unit of a real scalar: `z` when it is positive, zero otherwise
(promoted together with `z`, as `max` does). Broadcast (`relu.(x)`) for arrays.
"""
relu(z) = max(0, z)

"""
    predict(w, x)

Forward pass of a fully connected tanh network.

`w` is a flat list `[W1, b1, W2, b2, ...]`. Every `(W, b)` pair, the last one
included, maps its input `h` to `tanh.(W * h .+ b)`. The output of the final pair is
returned.
"""
function predict(w, x)
    # All (weight, bias) pairs except the final one act as hidden layers.
    nhidden = length(w) ÷ 2 - 1
    h = x
    for layer in 1:nhidden
        W, b = w[2layer - 1], w[2layer]
        h = tanh.(W * h .+ b)
    end
    # Output layer: the last pair, also squashed by tanh.
    return tanh.(w[end-1] * h .+ w[end])
end

"""
    loss1(w, x, y)

Mean Euclidean error between the network output `predict(w, x)` and the targets
`y`, averaged over the columns (samples) of `y`. This is the function that
`lossgradient` differentiates during training.
"""
function loss1(w, x, y)
    _, nsamples = size(y)
    residual = y - predict(w, x)
    total = 0
    # Add each per-sample norm, already divided by the sample count, one after another.
    for col in 1:nsamples
        total += vecnorm(residual[:, col]) / nsamples
    end
    return total
end

"""
    totalVar(y, ypred; nbins=100)

Compare two clouds of 2-D points through norm-weighted histograms on [-1, 1]^2.

Each column of `y` and of `ypred` is one sample. Its first two entries select a cell
of an `(nbins+1) x (nbins+1)` grid with cell width `2/nbins` starting at -1. The
sample's Euclidean norm (over the whole column) is added to that cell.

Returns the Frobenius norm of the difference of the two grids divided by `nbins`.

Samples are taken from the columns of `y`, so `ypred` must have at least as many
columns. Coordinates outside [-1, 1] give an out-of-bounds index. Neither input is
modified and nothing is printed.
"""
function totalVar(y, ypred; nbins::Integer=100)
    lo = -1               # lower edge of the grid in both coordinates
    width = 2.0 / nbins   # nbins cells span [-1, 1]
    # Cell index of coordinate v; the extra (nbins+1)-th cell keeps v == 1.0 in range.
    cell(v) = floor(Int, (v - lo) / width) + 1

    Y = zeros(nbins + 1, nbins + 1)
    X = zeros(nbins + 1, nbins + 1)
    for k in 1:size(y, 2)
        Y[cell(y[1, k]), cell(y[2, k])] += sqrt(sum(abs2, view(y, :, k)))
        X[cell(ypred[1, k]), cell(ypred[2, k])] += sqrt(sum(abs2, view(ypred, :, k)))
    end
    # Frobenius norm of the grid difference, written portably (no `vecnorm`).
    return sqrt(sum(abs2, Y - X)) / nbins
end

"""
    loss2(w, x, y)

Histogram-style mismatch between targets and predictions, used by `report`.

Let `yhat = predict(w, x)`. For every sample `k`, the residual norm of
`y[:,k] - yhat[:,k]` is added twice to 101x101 grids covering [-1, 1]^2 with cell
width 0.02:
- once to the target grid, in the cell of `(y[1,k], y[2,k])`;
- once to the prediction grid, in the cell of `(yhat[1,k], yhat[2,k])`.

Returns the Frobenius norm of the difference of the two grids divided by 100.
Coordinates outside [-1, 1] produce an out-of-bounds index.
"""
function loss2(w, x, y)
    ypred = predict(w, x)
    _, nsamples = size(y)

    nbins = 100
    width = 2.0 / nbins   # cell width over [-1, 1]
    lo = -1               # lower grid edge in both coordinates
    tobin(v) = floor(Int, (v - lo) / width) + 1

    target_hist = zeros(nbins + 1, nbins + 1)
    pred_hist = zeros(nbins + 1, nbins + 1)
    for k in 1:nsamples
        err = vecnorm(y[:, k] - ypred[:, k], 2)
        target_hist[tobin(y[1, k]), tobin(y[2, k])] += err
        pred_hist[tobin(ypred[1, k]), tobin(ypred[2, k])] += err
    end
    return vecnorm(target_hist - pred_hist, 2) / nbins
end

# Gradient of `loss1` with respect to the weights: `train` calls it as
# `lossgradient(w, x, y)` and passes the result to `update!` on `w`.
lossgradient = grad(loss1)

"""
    train(w, xtrn, ytrn, mu=1e-2, epochs=10; decay=1.0)

Fit the weights `w` with Adam, minimising `loss1` through `lossgradient` on
shuffled minibatches of 100 columns.

# Arguments
- `w`: weight list (as built by `init_weight`); updated in place and returned.
- `xtrn`, `ytrn`: inputs and targets, one sample per column.
- `mu`: initial Adam learning rate.
- `epochs`: number of passes over the data.
- `decay` (keyword): every 10 epochs the learning rate is multiplied by `decay` and
  written into the optimisers. The default `1.0` keeps the rate constant.

Progress is printed with `report` every 10 epochs, and once at the end unless the
last epoch was already reported.
"""
function train(w, xtrn, ytrn, mu=1e-2, epochs=10; decay=1.0)
    opts = optimizers(w, Adam; lr = mu)

    for epoch = 1:epochs
        for (x, y) in minibatch(xtrn, ytrn, 100; shuffle=true)
            g = lossgradient(w, x, y)
            Knet.update!(w, g, opts)
        end
        if epoch % 10 == 0
            if decay != 1
                # The optimisers copied `lr` when they were created, so changing `mu`
                # alone has no effect; push the new rate into each one.
                # NOTE(review): assumes `opts` is a flat collection of optimisers with a
                # mutable `lr` field (true for Knet's `Adam` over a Vector of arrays).
                mu *= decay
                for o in opts
                    o.lr = mu
                end
            end
            report(epoch, w)
        end
    end
    # Skip the final report when the loop above already printed this epoch.
    if epochs % 10 != 0 || epochs == 0
        report(epochs, w)
    end
    return w
end

"""
    init_weight(hidden_size)

Build the parameter list for a fully connected network.

`hidden_size` is a collection of `(in, out)` pairs, one per layer. For each layer the
result gets an `out x in` weight matrix followed by an `out x 1` bias column, both
drawn from `randn`. Returns a `Vector{Any}` laid out as `[W1, b1, W2, b2, ...]`,
which is the layout `predict` consumes.
"""
function init_weight(hidden_size)
    params = Any[]
    for (fanin, fanout) in hidden_size
        append!(params, Any[randn(fanout, fanin), randn(fanout, 1)])
    end
    return params
end

"""
    report(epoch, w)

Print `(:epoch, epoch, :trn, loss)`, where `loss` is `loss2` evaluated on the global
training data `xtrn`, `ytrn` (not on data passed to `train`).
"""
function report(epoch, w)
    trnloss = loss2(w, xtrn, ytrn)
    return println((:epoch, epoch, :trn, trnloss))
end

report (generic function with 1 method)

In [88]:
using ProgressMeter

# Number of training samples, and a derived count of M / 50.
M = 2000
N = M ÷ 50
# Training inputs: one 2-D standard-normal sample per column.
xtrn = randn(2, M)

2×2000 Array{Float64,2}:
 -0.30405   -0.370657  -1.52066    0.581724  …  0.704916  0.42309    -0.21182
 -0.721235  -0.437479  -0.847172  -1.15688      1.57007   0.0404583  -1.16424

In [89]:
# Scatter plot of the training inputs: first coordinate vs second, one point per column.
scatter(xtrn[1,:], xtrn[2,:])

In [92]:
# Random 2x2 matrix used in the synthetic target map below.
A = randn(2,2)
# NOTE(review): `B` is not used anywhere in the visible notebook; confirm whether it
# was meant to be a bias term in the target map.
B = rand(2,1)

# Synthetic targets: sin of a random linear map of tanh-squashed Gaussian inputs.
# NOTE(review): the inputs here are a fresh `randn(2,M)` draw, not `xtrn`, so each
# target column is independent of the matching input column. That only makes sense
# if the goal is to match the target distribution rather than a pointwise mapping;
# confirm, otherwise this should probably be `tanh.(xtrn)`.
ytrn = sin.(A*tanh.(randn(2,M)))
m1 = maximum(abs.(ytrn[1,:]));
m2 = maximum(abs.(ytrn[2,:]));

# Scale each output row by its largest magnitude so all targets lie in [-1, 1],
# the range assumed by the grid binning in `loss2` / `totalVar`.
ytrn[1,:] /= m1;
ytrn[2,:] /= m2;

In [93]:
# Scatter plot of the normalised targets: first coordinate vs second.
scatter(ytrn[1,:], ytrn[2,:])

In [94]:
# Preview of the shuffled 100-column minibatch iterator. `train` builds its own
# iterator every epoch, so this global `dtrn` is not used there.
dtrn = minibatch(xtrn,ytrn,100;shuffle=true)

Knet.MB([-0.30405 -0.370657 … 0.42309 -0.21182; -0.721235 -0.437479 … 0.0404583 -1.16424], [-0.122399 0.217671 … 0.706631 -0.408539; 0.804608 0.706271 … -0.59212 0.0690813], 100, 2000, false, [1818, 1650, 1523, 1899, 1263, 957, 1129, 505, 626, 1568  …  394, 1461, 30, 820, 1135, 1847, 47, 1376, 762, 574], [2, 2000], [2, 2000], Array{Float64,2}, Array{Float64,2})

In [100]:
# Network 2 -> 1024 (tanh) -> 2 (tanh): weights 1024x2 and 2x1024 plus bias columns.
w = init_weight([(2,1024),(1024,2)])
# Train for 3000 epochs with Adam at learning rate 1e-6.
w = train(w,xtrn,ytrn,1e-6,3000)

(:epoch, 10, :trn, 15.515874326396675)
(:epoch, 20, :trn, 15.454730782599468)
(:epoch, 30, :trn, 15.423628693066096)
(:epoch, 40, :trn, 15.445101652827676)
(:epoch, 50, :trn, 15.392532688721204)
(:epoch, 60, :trn, 15.348004323385789)
(:epoch, 70, :trn, 15.299781029461785)
(:epoch, 80, :trn, 15.219501462931014)
(:epoch, 90, :trn, 15.161451450996601)
(:epoch, 100, :trn, 15.129542525895303)
(:epoch, 110, :trn, 15.13815559379231)
(:epoch, 120, :trn, 15.14282468347751)
(:epoch, 130, :trn, 15.138507165677419)
(:epoch, 140, :trn, 15.104515348187281)
(:epoch, 150, :trn, 15.100914978600192)
(:epoch, 160, :trn, 15.098658135553615)
(:epoch, 170, :trn, 15.047228792565921)
(:epoch, 180, :trn, 15.032121766565632)
(:epoch, 190, :trn, 15.015763057150723)
(:epoch, 200, :trn, 15.036439619677044)
(:epoch, 210, :trn, 15.018047521948652)
(:epoch, 220, :trn, 14.998252600088735)
(:epoch, 230, :trn, 14.959758160166805)
(:epoch, 240, :trn, 14.940534030639146)
(:epoch, 250, :trn, 14.914173794898948)
(:epoch, 26

(:epoch, 2050, :trn, 12.885743060674178)
(:epoch, 2060, :trn, 12.884817588637707)
(:epoch, 2070, :trn, 12.890932948129873)
(:epoch, 2080, :trn, 12.875118117243119)
(:epoch, 2090, :trn, 12.866937664736708)
(:epoch, 2100, :trn, 12.871749964544836)
(:epoch, 2110, :trn, 12.8450696961506)
(:epoch, 2120, :trn, 12.820040010844357)
(:epoch, 2130, :trn, 12.802967699115968)
(:epoch, 2140, :trn, 12.768831909958697)
(:epoch, 2150, :trn, 12.728669029895995)
(:epoch, 2160, :trn, 12.73687932680503)
(:epoch, 2170, :trn, 12.741712436243187)
(:epoch, 2180, :trn, 12.73323058686655)
(:epoch, 2190, :trn, 12.701472495869481)
(:epoch, 2200, :trn, 12.677397836154908)
(:epoch, 2210, :trn, 12.677793495324511)
(:epoch, 2220, :trn, 12.659776689873596)
(:epoch, 2230, :trn, 12.661441758986685)
(:epoch, 2240, :trn, 12.625325542507658)
(:epoch, 2250, :trn, 12.62002736260387)
(:epoch, 2260, :trn, 12.619450343957826)
(:epoch, 2270, :trn, 12.609986666479731)
(:epoch, 2280, :trn, 12.612947049796372)
(:epoch, 2290, :trn, 

4-element Array{Any,1}:
 [-0.392536 1.75262; 0.186506 -0.551574; … ; 1.25282 0.0326088; -0.374833 -0.660809]
 [0.34521; 0.435482; … ; 1.01763; 0.846858]                                         
 [-1.18278 0.046375 … 1.7505 1.4055; -2.17113 -1.13487 … 1.34597 -1.96358]          
 [-1.44224; -1.37289]                                                               

In [39]:
# Network outputs on the training inputs, plotted first coordinate vs second.
ypred = predict(w,xtrn)
scatter(ypred[1,:],ypred[2,:])

In [None]:
# Count the samples whose first coordinate exceeds the second. Each column of `xtrn`
# is one 2-D sample, so compare rows 1 and 2 (columns 1 and 2 would compare just two
# samples).
sum(xtrn[1,:] .> xtrn[2,:])

$$f(x) = \sum_{n=0}^\infty f^{(n)}(0) \frac{x^n}{n!} $$

In [80]:
# Update all installed packages.
# NOTE(review): this changes the package environment for every later run of the
# notebook; consider running it outside the analysis notebook instead.
Pkg.update()

[1m[36mINFO: [39m[22m[36mUpdating METADATA...
[39m
[1m[36mINFO: [39m[22m[36mUpdating Images master...
[39m[1m[36mINFO: [39m[22m[36mComputing changes...
[39m[1m[36mINFO: [39m[22m[36mNo packages to install, update or remove
[39m