Anomaly detection with a GAN whose generator is trained using a feature-matching loss.

In [1]:
using Plots
plotly()
import Plots: plot
clibrary(:Plots)
using JLD
code_path = "../src/"
push!(LOAD_PATH, code_path)
using AnomalyDetection

In [2]:
# load the toy dataset; label 0 marks normal samples, label 1 the contamination
dataset = load("toy_data_3.jld")["data"]
Y = dataset.labels
X = AnomalyDetection.Float.(dataset.data)
# the model is fitted on the label-0 samples only
nX = X[:, Y .== 0]
M, N = size(X)  # M = data dimension (rows), N = number of samples (columns)

(2, 103)

In [3]:
# GAN settings
zdim = 1 # code (latent) dimension
xdim = M # dimension of data
hiddendim = 32

# setup the GAN model object
gsize = [zdim; hiddendim; hiddendim; xdim] # generator layout
dsize = [xdim; hiddendim*2; hiddendim*2; 1] # discriminator layout
lambda = 0.5 # anomaly score parameter in [0, 1]
# 1 - ignores the discriminator score
# 0 - ignores the reconstruction error score
threshold = 0 # classification threshold, recomputed later (getthreshold/setthreshold!,
# or when using fit!)
contamination = sum(Y .== 1) / length(Y) # contamination ratio = fraction of label-1 samples
L = 50 # batchsize
iterations = 10000 # no of iterations
cbit = 2500 # print progress every cbit iterations
verbfit = true # if output should be produced
pz = randn # code distribution (rand should also work)
activation = Flux.leakyrelu # should work better than relu
layer = Flux.Dense
rdelta = 1e-5 # stop training after this reconstruction error is achieved
# this parameter is basically useless in the case of GANs
alpha = 1.0 # weight of the classical generator loss in the total loss
# used to train generator
Beta = 1.0 # for automatic threshold computation, in [0, 1]
# 1.0 = tight around normal samples
tracked = true # do you want to store training progress?
# it can be later retrieved from model.history
model = fmGANmodel(gsize, dsize, lambda, threshold, contamination, L, iterations, cbit,
    verbfit, pz = pz, activation = activation, rdelta = rdelta, alpha = alpha,
    Beta = Beta, tracked = tracked, layer = layer)

AnomalyDetection.fmGANmodel(AnomalyDetection.fmGAN(Chain(Dense(1, 32, NNlib.leakyrelu), Dense(32, 32, NNlib.leakyrelu), Dense(32, 2)), Chain(Dense(1, 32, NNlib.leakyrelu), Dense(32, 32, NNlib.leakyrelu), Dense(32, 2)), Chain(Dense(2, 64, NNlib.leakyrelu), Dense(64, 64, NNlib.leakyrelu), Dense(64, 1, NNlib.σ)), Chain(Dense(2, 64, NNlib.leakyrelu), Dense(64, 64, NNlib.leakyrelu), Dense(64, 1, NNlib.σ)), randn), 0.5, 0, 0.1262135922330097, 50, 10000, 2500, true, 1.0f-5, 1.0f0, 1.0f0, MVHistory{ValueHistories.History})

In [4]:
# fit the model; one fixed code sample Z is reused so that the losses
# printed before and after training are evaluated on the same codes
Z = AnomalyDetection.getcode(model, size(nX, 2))
AnomalyDetection.evalloss(model, nX, Z)  # losses before training
AnomalyDetection.fit!(model, nX)         # verbfit = true: progress printed every cbit iterations
AnomalyDetection.evalloss(model, nX, Z)  # losses after training

discriminator loss: 0.6656314
feature-matching loss: 0.008114555
reconstruction error: 0.35584572

discriminator loss: 0.68620586
feature-matching loss: 0.0040919124
reconstruction error: 0.12610206

discriminator loss: 0.6870698
feature-matching loss: 0.004761703
reconstruction error: 0.14014542

discriminator loss: 0.6758728
feature-matching loss: 0.0051709283
reconstruction error: 0.13111414

discriminator loss: 0.7086017
feature-matching loss: 0.0048089647
reconstruction error: 0.117372625

discriminator loss: 0.7035248
feature-matching loss: 0.004312264
reconstruction error: 0.107279755



In [5]:
"""
    plot(model::fmGANmodel)

Plot the training curves stored in `model.history` against iteration: discriminator
loss, reconstruction error, generator loss and feature-matching loss, all in one figure
titled "model loss".

Return the plot object. If the model kept no history (it was not created with
`tracked = true`), print a hint and return `nothing`.
"""
function plot(model::fmGANmodel)
    # `===` is the reliable test for the `nothing` singleton; `==` can dispatch to
    # an `==` method defined for the history type
    if model.history === nothing
        println("No data to plot, set tracked = true before training.")
        return
    end
    p = plot(model.history[:discriminator_loss], title = "model loss",
        label = "discriminator loss",
        xlabel = "iteration", ylabel = "loss",
        seriestype = :line,
        markershape = :none)
    # add the remaining curves to `p` explicitly rather than to the implicit current plot
    plot!(p, model.history[:reconstruction_error], label = "reconstruction error",
        seriestype = :line, markershape = :none)
    plot!(p, model.history[:generator_loss], label = "generator loss",
        seriestype = :line, markershape = :none,
        c = :green)
    plot!(p, model.history[:feature_matching_loss], label = "feature-matching loss",
        seriestype = :line, markershape = :none)
    return p
end


RecipesBase.plot

In [6]:
# show the loss curves recorded during training
display(plot(model))
# a non-interactive session needs an explicit gui() call to display the figure
isinteractive() ? nothing : gui()

In [7]:
# draw as many samples from the trained generator as the dataset has columns (N)
Xgen = AnomalyDetection.generate(model, N)

2×103 Array{Float32,2}:
 0.754911  0.529348  0.0959175  0.76127  …  0.528229  0.053327  0.52425 
 0.819515  0.559883  0.965749   0.80772     0.558548  0.982901  0.589459

In [8]:
# plot the label-0 data and the generated samples over the discriminator output
xl = (minimum(X[1, :]) - 0.05, maximum(X[1, :]) + 0.05)  # axis limits with a small margin
yl = (minimum(X[2, :]) - 0.05, maximum(X[2, :]) + 0.05)
p = scatter(nX[1, :], nX[2, :], title = "discriminator contours",
    xlims = xl, ylims = yl, label = "data")
scatter!(p, Xgen[1, :], Xgen[2, :], label = "generated data", legend = (0.1, 0.8))

# discriminator value on a 30x30 grid; rows follow y, columns follow x
x = linspace(xl[1], xl[2], 30)
y = linspace(yl[1], yl[2], 30)
zz = Float64[AnomalyDetection.discriminate(model,
        AnomalyDetection.Float.([xv, yv]))[1] for yv in y, xv in x]
contourf!(p, x, y, zz, c = :viridis)

display(p)
isinteractive() ? nothing : gui()

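Classification is based on an anomaly score computed as a weighted average of the discriminator score and the reconstruction error:
\begin{equation}
A(x) = (1-\lambda)\,D(x) + \lambda\,\|x - G(z)\|_2, \qquad z \sim p(z),
\end{equation}
so $\lambda = 1$ uses only the reconstruction error and $\lambda = 0$ only the discriminator score.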

In [9]:
# refit the classification threshold on all samples, then label each of them
AnomalyDetection.setthreshold!(model, X)
tryhat = AnomalyDetection.predict(model, X)  # 1 = predicted positive, 0 = predicted negative

103-element Array{Int64,1}:
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 ⋮
 1
 1
 1
 0
 1
 1
 1
 1
 1
 1
 1
 1

In [10]:
# put lambda back to its configured value and refit the threshold before scoring
model.lambda = lambda
AnomalyDetection.setthreshold!(model, X)
# the same dataset is passed as both the training and the testing set
tryhat, tstyhat, _, _ = AnomalyDetection.rocstats(dataset, dataset, model);


 Training data performance: 
MLBase.ROCNums{Int64}
  p = 13
  n = 90
  tp = 13
  tn = 88
  fp = 2
  fn = 0
precision: 0.8666666666666667
f1score: 0.9285714285714286
recall: 1.0
false positive rate: 0.022222222222222223
equal error rate: 0.011111111111111112

 Testing data performance: 
MLBase.ROCNums{Int64}
  p = 13
  n = 90
  tp = 12
  tn = 89
  fp = 1
  fn = 1
precision: 0.9230769230769231
f1score: 0.9230769230769231
recall: 0.9230769230769231
false positive rate: 0.011111111111111112
equal error rate: 0.04401709401709402


In [11]:
# predicted labels drawn over the anomaly-score landscape
xl = (minimum(X[1, :]) - 0.05, maximum(X[1, :]) + 0.05)
yl = (minimum(X[2, :]) - 0.05, maximum(X[2, :]) + 0.05)
p = scatter(X[1, tryhat .== 1], X[2, tryhat .== 1], c = :red,
    label = "predicted positive",
    xlims = xl, ylims = yl, title = "classification results")
scatter!(p, X[1, tryhat .== 0], X[2, tryhat .== 0], c = :green,
    label = "predicted negative", legend = (0.7, 0.7))

# anomaly score on a 30x30 grid; rows follow y, columns follow x
x = linspace(xl[1], xl[2], 30)
y = linspace(yl[1], yl[2], 30)
zz = Float64[AnomalyDetection.anomalyscore(model,
        AnomalyDetection.Float.([xv, yv])) for yv in y, xv in x]
contourf!(p, x, y, zz, c = :viridis)

display(p)
isinteractive() ? nothing : gui()

In [12]:
# ROC curve built from the anomaly scores of all samples
ascore = AnomalyDetection.anomalyscore(model, X);
# returns (true positive rates, false positive rates)
# NOTE(review): the captured output ends fprvec with -0.0111, a negative FPR;
# check getroccurve's endpoint handling
recvec, fprvec = AnomalyDetection.getroccurve(ascore, Y)

([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  0.692308, 0.615385, 0.538462, 0.461538, 0.384615, 0.307692, 0.230769, 0.153846, 0.0769231, 0.0769231], [1.0, 0.988889, 0.977778, 0.966667, 0.955556, 0.944444, 0.933333, 0.922222, 0.911111, 0.9  …  1.53003e-15, 1.53003e-15, 1.53003e-15, 1.53003e-15, 1.53003e-15, 1.53003e-15, 1.53003e-15, 1.53003e-15, 1.53003e-15, -0.0111111])

In [13]:
"""
    plotroc(args...)

Draw one or more ROC curves on a single set of axes and return the plot. Each argument
is a tuple `(false positive rates, true positive rates, label)`; a gray diagonal is
drawn as the reference line.
"""
function plotroc(args...)
    unitline = linspace(0, 1, 100)
    p = plot(unitline, unitline, c = :gray, alpha = 0.5, xlim = [0, 1], ylim = [0, 1],
        label = "", xlabel = "false positive rate", ylabel = "true positive rate",
        title = "ROC")
    for (fpr, tpr, name) in args
        plot!(p, fpr, tpr, label = name, lw = 2)
    end
    return p
end

# a single curve: the feature-matching GAN
plargs = [(fprvec, recvec, "fmGAN")]
display(plotroc(plargs...))
isinteractive() ? nothing : gui()



In [14]:
# plot EER for different settings of lambda
using MLBase: roc, correctrate, precision, recall, f1score, false_positive_rate, false_negative_rate
n = 21
lvec = linspace(0, 1, n)
eervec = zeros(n)
for (i, lam) in enumerate(lvec)
    model.lambda = lam
    AnomalyDetection.setthreshold!(model, X)
    # only the testing ROC counts are needed; discard the rest instead of
    # overwriting the global `tryhat` computed by the earlier cells
    _, _, _, tstroc = AnomalyDetection.rocstats(dataset.data, dataset.labels,
        dataset.data, dataset.labels, model, verb = false)
    # "EER" here is the mean of the false positive and false negative rates at the
    # chosen threshold; this matches the "equal error rate" printed by rocstats above
    eervec[i] = (false_positive_rate(tstroc) + false_negative_rate(tstroc))/2
end
# the sweep mutates the model; restore the configured lambda and its threshold so
# re-running the earlier cells scores with the original settings
model.lambda = lambda
AnomalyDetection.setthreshold!(model, X);

In [15]:
# error rate as a function of the anomaly-score weight lambda
plot(lvec, eervec;
    title = "equal error rate vs lambda", xlabel = "lambda", ylabel = "EER")