In [None]:
module Balance

# Utilities to split a dataset into train/test parts, optionally oversample the
# training part with imbalanced-learn (via PyCall), and plot class distributions
# with Gadfly.
#
# Data layout: the LAST axis of an array indexes instances ("columns"); the
# leading axes give each instance's shape. A 1-d array is treated as one column.

# Tuple of `Colon()`s selecting every leading (non-column) axis of `a`.
_leadcolons(a) = ntuple(d -> Colon(), ndims(a) - 1)

# Column accessors need a separate column axis; reject 1-d input with a message.
_checkcols(a) = ndims(a) == 1 &&
    error("Balance: column access needs an array with at least 2 dimensions")

"Shape of one column of `a` (all of `size(a)` for a 1-d array)."
csize(a) = (ndims(a) == 1 ? size(a) : size(a)[1:end-1])

"Size `a` would have if it had `n` columns."
csize(a, n) = tuple(csize(a)..., n)

# Product of the column shape rather than `stride`, so views and other
# non-contiguous arrays report their own column length, not their parent's.
"Number of elements in one column of `a` (`length(a)` for a 1-d array)."
clength(a) = (ndims(a) == 1 ? length(a) : prod(csize(a)))

"Number of columns of `a` (1 for a 1-d array)."
ccount(a) = (ndims(a) == 1 ? 1 : size(a, ndims(a)))

"View (no copy) of column(s) `i` of `a`; `a` must have at least 2 dimensions."
function csub(a, i)
    _checkcols(a)
    return view(a, _leadcolons(a)..., i)   # `view` replaces the removed `sub`
end

"`getindex` of column(s) `i` of `a`; `a` must have at least 2 dimensions."
function cget(a, i)
    _checkcols(a)
    return a[_leadcolons(a)..., i]
end

"Store `x` into column(s) `i` of `a` via `setindex!`; `a` must have at least 2 dimensions."
function cset!(a, x, i)
    _checkcols(a)
    return setindex!(a, x, _leadcolons(a)..., i)
end

"Size of `y` viewed as a matrix: `(column length, column count)`."
size2(y) = (ndims(y) == 1 ? (length(y), 1) : (prod(csize(y)), ccount(y)))
size2(y, i) = size2(y)[i]

"""
    @repeat n ex

Evaluate `ex` `n` times in the caller's scope. Both `n` and `ex` are escaped,
so `n` may be any expression visible to the caller (not only a literal).
"""
macro repeat(n, ex)
    quote
        for t = 1:$(esc(n))
            $(esc(ex))
        end
    end
end

using PyCall
@pyimport imblearn.combine as cb
@pyimport imblearn.over_sampling as os

# Fresh imblearn sampler for `method` (`:SMOTE` or `:SMOTETomek`).
# NOTE(review): `ratio` / `fit_sample` are the older imblearn names; newer
# releases use `sampling_strategy` / `fit_resample` — confirm installed version.
_sampler(method) = method == :SMOTE ? os.SMOTE(ratio = 1.0) : cb.SMOTETomek(ratio = 1.0)

# Run `rounds` successive resampling passes of `method` over the training split.
# `xtrn` is flattened to (column length × columns) and transposed when handed to
# imblearn; the result is reshaped back so each column has shape `colshape`.
function _resample(method, xtrn, ytrn, colshape; rounds = 3)
    xtrn = reshape(xtrn, size2(xtrn))
    for pass = 1:rounds
        sampler = _sampler(method)
        rx, ry = sampler[:fit_sample](xtrn', vec(ytrn))
        xtrn = rx'
        # NOTE(review): this leaves labels as a 1×n row, whereas the
        # non-resampled `ytrn` from `balance` is a Vector — callers may rely on it.
        ytrn = ry'
    end
    return reshape(xtrn, (colshape..., ccount(xtrn))), ytrn
end

"""
    balance(x, y; imblearn = false)

Split `x` and `y` into a training part (first 4/5) and a test part (rest).

`x` is split by columns (last axis); `y` is split by linear index over its own
length. If `imblearn` is `:SMOTE` or `:SMOTETomek`, the training part only is
resampled for three rounds with that imbalanced-learn method and reshaped back
to the column shape of `x`. Any other `imblearn` value leaves it unchanged.

Returns `(xtrn, ytrn, xtst, ytst)`.
"""
function balance(x, y; imblearn = false)
    nx = ccount(x)
    ntrn = 4 * nx ÷ 5
    xtrn = cget(x, 1:ntrn)
    xtst = cget(x, ntrn+1:nx)
    ytrn = y[1:4*end÷5]
    ytst = y[4*end÷5+1:end]

    if imblearn == :SMOTE || imblearn == :SMOTETomek
        # Restore each column to the shape of an input column of `x`.
        xtrn, ytrn = _resample(imblearn, xtrn, ytrn, csize(x))
    end

    return xtrn, ytrn, xtst, ytst
end

# Module global `nclass`: defined on first use as `maximum(label) + 1` and reused
# by every later call. Labels are presumably 0-based class ids — confirm.
function _nclass!(label)
    if !isdefined(Balance, :nclass)
        @eval nclass = $(Int(maximum(label)) + 1)
    end
    return nclass
end

# Count entries of `label` per class 0:k-1 with the same bins as
# `hist(label, -0.5:1:k-0.5)`: class c collects values in (c-0.5, c+0.5], and
# values outside (-0.5, k-0.5] (including NaN) are ignored.
function _bincounts(label, k)
    freq = zeros(Int, k)
    for l in label
        -0.5 < l <= k - 0.5 || continue
        freq[ceil(Int, l - 0.5) + 1] += 1
    end
    return freq
end

"""
    histlabel(label)

Return the fraction of `label` entries falling into each class bin `0:nclass-1`.

On the first call this defines the module global `nclass` from
`maximum(label) + 1`; later calls reuse it. If no entry falls into any bin, a
vector of zeros is returned instead of NaNs.
"""
function histlabel(label)
    k = _nclass!(label)
    freq = _bincounts(label, k)
    total = sum(freq)
    return total == 0 ? zeros(k) : freq / total
end

using Gadfly, DataFrames

# One row per class with the label share of `labels` (column named `percent`,
# values are fractions from `histlabel`), tagged with `group`.
_distframe(labels, group, k) = DataFrame(
    class = 0:k-1,
    percent = histlabel(labels),
    group = group
)

"""
    imbalance_view(ytrn, ytst, preds)

Display three Gadfly plots: "Error Violin" (predicted vs. true test class),
"Error Contribution" (class shares among misclassified test labels), and
"Distribution" (train / test / predicted class shares side by side).
"""
function imbalance_view(ytrn, ytst, preds)
    # Define `nclass` (if still undefined) from all given labels before the tick
    # axes below read it.
    k = _nclass!(vcat(vec(ytrn), vec(ytst), vec(preds)))
    xticks = Guide.xticks(ticks = [0:k-1;])
    yticks = Guide.yticks(ticks = [0:k-1;])

    df = DataFrame(
        class = ytst,
        predict = preds
    )
    plot(df, x = :class, y = :predict, Geom.violin, xticks, yticks,
         Guide.title("Error Violin")) |> display

    df = DataFrame(
        class = 0:k-1,
        percent = histlabel(ytst[ytst .!= preds])
    )
    plot(df, x = :class, y = :percent, Geom.bar, xticks,
         Guide.title("Error Contribution")) |> display

    df = vcat(_distframe(ytrn, "train", k),
              _distframe(ytst, "test", k),
              _distframe(preds, "predict", k))
    plot(df, x = :class, y = :percent, color = :group,
         Geom.bar(position=:dodge), xticks,
         Guide.title("Distribution")) |> display
end

end # end of module Balance