In [2]:
using Knet
# Test if Knet is using gpu
Knet.gpu()

└ @ CuArrays /kuacc/users/ssafadoust20/.julia/packages/CuArrays/A6GUx/src/CuArrays.jl:122


0

In [3]:
using Pkg;

# Install missing packages
for p in ["Knet", "MLJ", "MLJModels", "Distributions", "Plots"]
    if !haskey(Pkg.installed(),p)
        Pkg.add(p);
    end
end

using Knet, Plots, Random,  MLJ, Distributions, LinearAlgebra

In [93]:
atype = (Knet.gpu()>=0 ? Knet.KnetArray{Float32} : Array{Float32})

KnetArray{Float32,N} where N

In [94]:
struct Mask; d; reverse; end

#one argument: (mask) will return x_id and x_change
#two argument: (unmask) will_return concat(y_id, y_change)
function (mask::Mask)(x) 
    len = size(x, 1)
    b = convert(atype,zeros(len,1))
    d = mask.d
    if mask.reverse 
        b[d+1:end,1] .= 1
    else
        b[1:d,1] .= 1
    end
    x_id = x .* b
    x_change = x .* (1 .- b)
    return x_id, x_change
end
function (mask::Mask)(y_id, y_change)
    len = size(y_id, 1)
    b = convert(atype,zeros(len,1))
    d = mask.d
    if mask.reverse 
        b[d+1:end,1] .= 1
    else
        b[1:d,1] .= 1
    end
    return y_id .* b + y_change .* (1 .- b)
end


struct Sequential
    layers
    Sequential(layers...) = new(layers)
end
(s::Sequential)(x) = (for l in s.layers; x = l(x); end; x)

struct DenseLayer; w; b; f; end

DenseLayer(i::Int,o::Int, f=relu) = DenseLayer(param(o,i), param0(o), f)

(d::DenseLayer)(x) = d.f.(d.w * x .+ d.b)



#Coupling Layer
mutable struct CouplingLayer; st_net::Sequential; mask::Mask; logdet; end

function CouplingLayer(;in_dim::Int, hidden_dim::Int, num_layers::Int, mask::Mask)
    layers = []
    push!(layers, DenseLayer(in_dim, hidden_dim, relu))
    for layer in 1:num_layers
        push!(layers, DenseLayer(hidden_dim, hidden_dim, relu))
    end
    push!(layers, DenseLayer(hidden_dim, 2*in_dim, identity))
    st_net = Sequential(layers...)
    CouplingLayer(st_net, mask, 0.0)
end

function (cpl::CouplingLayer)(x)
    x_id, x_change, s, t = get_s_and_t(cpl, x)
#     y_change = x_change .* exp.(s) .+ t #in original code, first addition is performed, then exponentiation
    y_change = (x_change .+ t) .* exp.(s) 
    y_id = x_id
    cpl.logdet = sum(s; dims=1)
    return cpl.mask(y_id, y_change)
end
#st is a neural network, the first part of the output is used as s, second part as t
function get_s_and_t(cpl::CouplingLayer, x)
    x_id, x_change = cpl.mask(x)
    st = cpl.st_net(x_id)
    middle = (size(st)[1]+1)÷2
    s, t = st[1:middle,:], st[middle+1:end,:]
    s = tanh.(s)
    return (x_id, x_change, s, t)
end


struct RealNVP; seq::Sequential; end

function RealNVP(;in_dim::Int, hidden_dim::Int, num_coupling_layers::Int, num_hidden_layers::Int)
    coupling_layers = []
    for i in 1:num_coupling_layers
        push!(coupling_layers, CouplingLayer(;in_dim=in_dim, hidden_dim=hidden_dim, num_layers=num_hidden_layers, mask=Mask(div(in_dim,2), Bool((i+1) %2))))
    end
    seq = Sequential(coupling_layers...)
    RealNVP(seq)
end

(realnvp::RealNVP)(x) = realnvp.seq(x)

function logdet(realNVP::RealNVP)
    total_logdet = 0.0
    for cpl in realNVP.seq.layers
        total_logdet = total_logdet .+ cpl.logdet
    end
    return total_logdet
end

logdet (generic function with 1 method)

In [95]:
function make_moons_ssl()
    Knet.seed!(2020)
    Random.seed!(2020)
    n_samples = 1000
    data = MLJ.make_moons(n_samples;noise=0.05)
    data = convert(atype, permutedims(hcat(data[1][1], data[1][2])))
    labels = convert(atype, ones(1,n_samples)) * (-1)
    idx1 = [1 2 4 5 6]
    labels[idx1] .= 1
    idx0 = [3 7 8 11 18]
    labels[idx0] .= 0
    return data, labels
end

make_moons_ssl (generic function with 1 method)

In [96]:
function mylogpdf(g,x)
    xx = convert(Array{Float32}, Knet.value(x))
    ans = Distributions.logpdf(g, xx)
    return convert(atype, ans)
end

function mygradlogpdf(g,x)
    xx = convert(Array{Float32}, Knet.value(x))
    ans = []
    for i in 1:size(xx,2)
        push!(ans, Distributions.gradlogpdf(g,xx[:,i]))
    end
    ans = hcat(ans...)
    return convert(atype, ans) 
end

@Knet.primitive mylogpdf(g,x),dy 1 reshape(dy, (1,length(dy))).*mygradlogpdf(g,x) 

In [97]:
struct Prior; means; n_components; d; gaussians; weights; end
#n_components: number of classes
#d: feature dimenstion of data points
#means: d x n_components
#gaussians: we have n_components multivariate-gaussians, each with size d
function Prior(means)
    d, n_components = size(means)
    weights = convert(atype, ones(1, n_components))
    gaussians = []
    for i in 1:n_components
        mu = means[:,i]
        sig = Matrix{Float64}(I, d, d)
        push!(gaussians, MvNormal(mu, sig))
    end
    Prior(means, n_components, d, gaussians, weights)
end

function log_prob(prior::Prior, z, labels=nothing; label_weight=1.0)
    all_log_probs = []
    for g in prior.gaussians
        push!(all_log_probs, mylogpdf(g, z))
    end
    all_log_probs = hcat(all_log_probs...) #n_instances x n_components
    mixture_log_probs = logsumexp(all_log_probs .+ log.(softmax(prior.weights)); dims=2)
    if labels == nothing
        return mixture_log_probs
    else
        #log_probs = convert(atype, zeros(size(mixture_log_probs)))
        len = size(mixture_log_probs, 1)
        int_labels = permutedims(convert(Array{Int32}, labels))
        c_mixture = convert(atype, zeros(len,1))
        mask_mixture = [index[1] for index in findall(label->label==-1, int_labels)]
        c_mixture[mask_mixture,1] .= 1
        log_probs = c_mixture .* mixture_log_probs
        for i in 1:prior.n_components
            c_all_log_probs = convert(atype, zeros(len,1))
            mask = [index[1] for index in findall(label->label==(i-1), int_labels)]
            c_all_log_probs[mask] .=  label_weight
            log_probs += (c_all_log_probs .* all_log_probs[:,i:i])
        end  
        return log_probs
    end
end

log_prob (generic function with 2 methods)

In [98]:
function flow_loss(z, logdet, labels, prior; k=256)
    prior_ll = log_prob(prior, z, labels)
    #I dont know why we are doing this correction
    batch_size = size(z,2)
    kk = length(z) / batch_size
    
    corrected_prior_ll = prior_ll .- log(k) * kk
    if logdet == 0
        ll = corrected_prior_ll
    else
        ll = corrected_prior_ll + permutedims(logdet)
    end
    nll = -mean(ll)
    return nll
end

flow_loss (generic function with 1 method)

In [99]:
function forward(realnvp, data, labels, prior)
    z = realnvp(data)
    sldj = logdet(realnvp)
    return flow_loss(z, sldj, labels, prior)
end

forward (generic function with 1 method)

In [100]:
data, labels = make_moons_ssl()
prior = Prior([-3.5 3.5; -3.5 3.5])
realnvp = RealNVP(in_dim=2, hidden_dim=512, num_coupling_layers=5, num_hidden_layers=1)
loss = @diff forward(realnvp, data, labels, prior)

T(23.626392)

In [102]:
data, labels = make_moons_ssl()
prior = Prior([-3.5 3.5; -3.5 3.5])
realnvp = RealNVP(in_dim=2, hidden_dim=512, num_coupling_layers=5, num_hidden_layers=1)

lr = 1e-4
epochs = 2001

num_unlabeled = Int(sum(labels .== -1))
num_labeled = size(labels)[2] - num_unlabeled
batch_size = num_labeled
print_freq = 500

int_labels = convert(Array{Int32}, labels)

mask_labeled = [index[2] for index in findall(label->label!=-1, int_labels)]
labeled_data = data[:,mask_labeled]
labeled_labels = labels[mask_labeled]

mask_unlabeled = [index[2] for index in findall(label->label==-1, int_labels)]
unlabeled_data = data[:, mask_unlabeled]
unlabeled_labels = labels[mask_unlabeled]

for p in Knet.params(realnvp)
    p.opt = Adam(;lr=lr)
end
for epoch in 1:epochs
    batch_idx = Distributions.sample(1:num_unlabeled, batch_size, replace=true)
    batch_x, batch_y = unlabeled_data[:, batch_idx], unlabeled_labels[batch_idx]
    batch_x = hcat(batch_x, labeled_data)
    batch_y = vcat(batch_y, labeled_labels)
    batch_y = reshape(batch_y, (1, size(batch_y)[1]))

    loss = @diff forward(realnvp, batch_x, batch_y, prior)
     
    for p in Knet.params(realnvp)
        g = Knet.grad(loss, p)
        update!(Knet.value(p), g, p.opt)
    end
    if epoch % print_freq == 1
        print("iter ")
        print(epoch)
        print(" loss: ")
        print(loss)
        println(" ")
    end
    if epoch == Int(floor(epochs * 0.5)) || epoch == Int(floor(epochs * 0.8))
        lr /= 10
        for p in Knet.params(realnvp)
            p.opt = Adam(;lr=lr)
        end
    end
end

iter 1 loss: T(25.00743) 
iter 501 loss: T(9.406581) 
iter 1001 loss: T(8.093089) 
iter 1501 loss: T(7.95645) 
iter 2001 loss: T(8.393528) 
