In [31]:
using Knet

"""
    sigmoid(z)

Logistic function `1 / (1 + exp(-z))` for a scalar `z` (broadcast with `sigmoid.(A)`).

The constants are built with `one(z)` so the result keeps the precision of `z`:
a `Float32` input yields a `Float32` result instead of being promoted to `Float64`.
Integer inputs still produce a `Float64`, because `exp` returns a float.
"""
function sigmoid(z)
    return one(z) / (one(z) + exp(-z))
end

"""
    predict(w, x)

Forward pass of a one-hidden-layer `tanh` network.

`tanh.(w[1] * mat(x) .+ w[2])` computes the hidden activations. These are summed
over dimension 1 (the hidden units), so the result has a single row. Only `w[1]`
and `w[2]` are used; there is no separate output weight.
"""
function predict(w, x)
    hidden = tanh.(w[1] * mat(x) .+ w[2])
    return sum(hidden, 1)
end

"""
    loss(w, x, y)

Mean squared error between the targets `y` and `predict(w, x)`: the sum of squared
residuals divided by `length(y)`.
"""
function loss(w, x, y)
    residual = y - predict(w, x)
    return sum(residual .^ 2) / length(y)
end

# Gradient of `loss` with respect to its first argument (the weights `w`);
# `train` calls it as lossgradient(w, x, y) and passes the result to `update!`.
lossgradient = grad(loss, 1)

"""
    train(w, dtrn, mu=1e-2, epochs=10; report_every=500)

Run `epochs` passes over the minibatches `dtrn`, updating the weights `w` with
Knet's `update!` and a `Nesterov` optimizer at learning rate `mu`. Return `w`.

Progress is printed with `report(epoch, w)`:
- every `report_every` epochs;
- once after the final epoch, unless that epoch was just reported.

Note that `report` evaluates the loss on the global `xtrn`/`ytrn`, not on `dtrn`.

Throws `ArgumentError` if `report_every` is not positive.
"""
function train(w, dtrn, mu=1e-2, epochs=10; report_every=500)
    report_every > 0 || throw(ArgumentError("report_every must be positive, got $report_every"))
    opt = optimizers(w, Nesterov; lr = mu)
    last_reported = -1
    for epoch = 1:epochs
        for (x, y) in dtrn
            g = lossgradient(w, x, y)
            update!(w, g, opt)
        end
        if epoch % report_every == 0
            report(epoch, w)
            last_reported = epoch
        end
    end
    # Always show the final state, but avoid printing the same epoch twice when the
    # last epoch coincided with a periodic report.
    last_reported == epochs || report(epochs, w)
    return w
end

"""
    init_weight(hidden_size)

Build the weight list for `predict`. For each `(nin, nout)` pair in `hidden_size`,
append a `nout × nin` weight matrix and then a `nout × 1` bias, both drawn from
`randn`. The result is an untyped `Any[W1, b1, W2, b2, ...]`.
"""
function init_weight(hidden_size)
    params = Any[]
    for (nin, nout) in hidden_size
        push!(params, randn(nout, nin), randn(nout, 1))
    end
    return params
end

# Print `(:epoch, epoch, :trn, loss)` for the weights `w`.
# By default the loss is computed on the global `xtrn`/`ytrn`, which are looked up at
# call time, so existing `report(epoch, w)` calls behave as before. Pass `x`/`y` to
# report on other data without touching globals.
report(epoch, w, x=xtrn, y=ytrn) = println((:epoch, epoch, :trn, loss(w, x, y)))

report (generic function with 1 method)

In [7]:
using JLD, Interact, Plots

In [22]:
# Load the data matrix Y (the file name suggests PCA output from the GAN experiments).
D = load("./../GAN/PCA0.jld")
Y = D["Y"];

In [23]:
# Load W. X = pinv(W) * Y is the minimum-norm least-squares solution of W * X ≈ Y.
# The rows of X are histogrammed and modelled one at a time below.
D = load("./../GAN/ICAb476.jld")
W = D["W"]
X = pinv(W) * Y;

In [30]:
# Interactive histogram of row k of X (slider over the first 12 rows).
@manipulate for k in 1:12
    histogram(X[k, :])
end

In [34]:
# For every row i of X, build one training pair for the per-row networks:
#   x[i,:] - sorted N(0,1) draws (network input),
#   y[i,:] - sorted X[i,:] divided by its largest |value|, so it lies in [-1, 1],
#   c[i]   - that largest |value|, needed to undo the scaling later.
# The row count is taken from X instead of being hard-coded, so this keeps working
# if the number of rows changes.
nrows = size(X, 1)
y = zeros(eltype(X), size(X))   # same as the Julia 0.6-only `zeros(X)`
x = randn(size(X))
c = zeros(nrows)
for i = 1:nrows
    x[i,:] = sort(x[i,:])
    y[i,:] = sort(X[i,:])
    c[i] = maximum(abs, y[i,:])
    # Guard against an all-zero row: dividing by 0 would fill y[i,:] with NaN.
    if c[i] > 0
        y[i,:] = y[i,:] / c[i]
    end
end


In [37]:
# One trained weight list per modelled row is pushed here by the training loop.
weights = Vector{Any}()

0-element Array{Any,1}

In [38]:
# Fit one small network per row i of X (first 12 rows only). Each network is trained
# to map the sorted N(0,1) draws x[i,:] to the sorted, rescaled values y[i,:].
for i = 1:12
    # `report` (called inside `train`) reads the *global* xtrn/ytrn, so they must be
    # assigned as globals rather than loop-locals. Declare this explicitly instead of
    # relying on soft-scope rules, which differ between Julia versions.
    global xtrn, ytrn
    xtrn = x[[i],:]
    ytrn = y[[i],:]
    w = init_weight([(1,96)])
    dtrn = minibatch(xtrn, ytrn, 100; shuffle=true)
    w = train(w, dtrn, 1e-2, 3000)
    push!(weights, w)
end

(:epoch, 500, :trn, 4.665786274110742e-5)
(:epoch, 1000, :trn, 4.272051662304153e-5)
(:epoch, 1500, :trn, 4.122335514232997e-5)
(:epoch, 2000, :trn, 4.021969189504804e-5)
(:epoch, 2500, :trn, 3.940872104880069e-5)
(:epoch, 3000, :trn, 3.8717232576389214e-5)
(:epoch, 3000, :trn, 3.8717232576389214e-5)
(:epoch, 500, :trn, 4.220544422561532e-5)
(:epoch, 1000, :trn, 2.6857887983539032e-5)
(:epoch, 1500, :trn, 2.259688775849973e-5)
(:epoch, 2000, :trn, 2.120998663786717e-5)
(:epoch, 2500, :trn, 2.0597731370200546e-5)
(:epoch, 3000, :trn, 2.0214610668987074e-5)
(:epoch, 3000, :trn, 2.0214610668987074e-5)
(:epoch, 500, :trn, 3.834332076956901e-5)
(:epoch, 1000, :trn, 3.0693608696804724e-5)
(:epoch, 1500, :trn, 2.682701338101078e-5)
(:epoch, 2000, :trn, 2.4654021714584427e-5)
(:epoch, 2500, :trn, 2.3318116317314602e-5)
(:epoch, 3000, :trn, 2.2419038770754907e-5)
(:epoch, 3000, :trn, 2.2419038770754907e-5)
(:epoch, 500, :trn, 9.408740206025577e-5)
(:epoch, 1000, :trn, 5.93419701877939e-5)
(:epo

In [27]:
# Sanity check: the training loop pushes one weight list per modelled row.
(length(weights),)

(12,)

In [40]:
# Save the trained weight lists ("w") and the per-row scale factors ("c") to w2.jld,
# overwriting any existing file.
jldopen("w2.jld", "w") do io
    write(io, "w", weights)
    write(io, "c", c)
end

In [39]:
# Interactive check of the network trained for row k:
# 1. push 1000 fresh N(0,1) draws through it;
# 2. multiply by c[k] to undo the [-1, 1] target scaling;
# 3. histogram the result.
# The input is named `z` so it cannot shadow or overwrite the global training inputs `x`.
@manipulate for k = 1:12
    w = weights[k]
    z = randn(1, 1000)
    yg = c[k] * predict(w, z)
    histogram(yg[1, :])
end