# Results

In [1]:
# using Revise
using LinearAlgebra, Random
using StatsBase, Statistics
using Distributions, MultivariateStats   # Categorical, P(P)CA
using Quaternions    # For manipulating 3D Geometry
using MeshCat        # For web visualisation / animation
using PyPlot         # Plotting
using AxUtil         # Cayley, skew matrices
using Flux           # Optimisation
using DSP            # convolution / low-pass (MA) filter

# small utils libraries
using ProgressMeter, Formatting, ArgCheck, Dates
using BSON, NPZ

In [2]:
DIR_MOCAP_MTDS = "."

# Project-local source files, loaded in this order:
#   io.jl        - data loading and transformation utils
#   mocap_viz.jl - MeshCat skeleton visualisation tools
#   util.jl      - data scaling utils
#   models.jl    - models: LDS
#   pretty.jl    - table visualisation
for srcfile in ("io.jl", "mocap_viz.jl", "util.jl", "models.jl", "pretty.jl")
    include(joinpath(DIR_MOCAP_MTDS, srcfile))
end

In [3]:
############################################
##    CUSTOM WIDELY USED FUNCTIONS
"""
    zero_grad!(P)

Set the accumulated gradient of every parameter in `P` to zero, in place.
Each element of `P` must expose a mutable array field `grad`. Returns `nothing`.
"""
function zero_grad!(P)
    foreach(param -> fill!(param.grad, 0), P)
    return nothing
end

# Model types (without gradient tracking) that share the generic evaluation methods below.
const NoGradModels = Union{model.MyLDS_ng, model.ORNN_ng}
# Memo table for dataset variances; IdDict => keyed by object identity, not contents.
const _var_cache = IdDict{Any, Any}()

"""
    mse(Δ, scale=size(Δ, 1))

Mean of the squared entries of the residual array `Δ`, multiplied by `scale`.
With the default `scale` (number of rows) this is the per-column sum of squares
averaged over the columns.
"""
function mse(Δ::AbstractArray, scale=size(Δ, 1))
    mean_sq = mean(r -> r^2, Δ)
    return mean_sq * scale
end

"""
    mse(d::mocaputil.DataIterator, m::NoGradModels)

Weighted MSE of model `m` over all segments of `d`. Each segment error is `mse(m(u) - y)`.
The errors are combined with `mocaputil.weights(d, as_pct=true)`. The hidden state `m.h` is
zeroed whenever a segment starts a new sequence, and again once all segments are done.
"""
function mse(d::mocaputil.DataIterator, m::NoGradModels)
    clear_state!() = (m.h .= zeros(size(m, 1)))
    seg_errors = map(d) do segment
        y, u, new_state = segment
        new_state && clear_state!()   # new sequence: forget the previous hidden state
        return mse(m(u) - y)
    end
    clear_state!()
    return dot(seg_errors, mocaputil.weights(d, as_pct=true))
end


# Convenience overloads:
#   * a vector of sequence Dicts: wrapped in a single `DataIterator` whose segment length
#     (1000000) is presumably long enough that each sequence is one segment — confirm;
#   * a single Dict with `:U` inputs and `:Y` targets;
#   * a tuple `(Y, U)`.
function mse(Ds::Vector{D}, m::NoGradModels) where {D <: Dict}
    return mse(mocaputil.DataIterator(Ds, 1000000), m)
end
function mse(D::Dict, m::NoGradModels)
    prediction = m(D[:U])
    return mse(prediction - D[:Y])
end
function mse(V::Tuple, m::NoGradModels)
    target, input = V[1], V[2]
    return mse(m(input) - target)
end

# Calculate variance
"""
    _calc_var!(cache, d::mocaputil.DataIterator)

Compute the row-wise (per-feature) variance of all target segments `y` yielded by `d`,
store it in `cache[d]` and return it.
The result is written to the `cache` argument rather than a hard-coded global, so the
function honours the dictionary it is given.
"""
function _calc_var!(cache::IdDict, d::mocaputil.DataIterator)
    Y = reduce(hcat, [y for (y, _, _) in d])
    return cache[d] = var(Y, dims=2)
end

"""
    _calc_var!(cache, d::Vector{<:Dict})

Compute the row-wise (per-feature) variance of the `:Y` matrices of all Dicts in `d`,
concatenated along columns. Store it in `cache[d]` and return it.
The result is written to the `cache` argument rather than a hard-coded global.
"""
function _calc_var!(cache::IdDict, d::Vector{D}) where {D <: Dict}
    Y = reduce(hcat, [dd[:Y] for dd in d])
    return cache[d] = var(Y, dims=2)
end

# Per-feature variance of a dataset's targets. The result is memoised in `_var_cache`, which
# is keyed by object identity, so a dataset mutated after the first call returns stale values.
# NOTE(review): these methods extend `Statistics.var` for `Vector{<:Dict}` / `Dict`, which is
# type piracy (neither the function nor the argument types are owned here).
function Statistics.var(d::Union{mocaputil.DataIterator, Vector{D}}) where {D <: Dict}
    haskey(_var_cache, d) || _calc_var!(_var_cache, d)
    return _var_cache[d]
end
# A single sequence Dict: row-wise variance of its `:Y` matrix (not cached).
function Statistics.var(d::Dict)
    return var(d[:Y], dims=2)
end

# Standardised MSE: the MSE divided by the summed (over features) variance.
# NOTE(review): for a bare residual array the normaliser is the variance of the residual `Δ`
# itself, not of the targets — confirm this is intended.
smse(Δ::AbstractArray, scale=size(Δ, 1)) = mse(Δ, scale) / sum(var(Δ, dims=2))

function smse(d::mocaputil.DataIterator, m::NoGradModels)
    return mse(d, m) / sum(var(d))
end
function smse(D::Dict, m::NoGradModels)
    return mse(m(D[:U]) - D[:Y]) / sum(var(D))
end
function smse(Ds::Vector{D}, m::NoGradModels) where {D <: Dict}
    return mse(mocaputil.DataIterator(Ds, 1000000), m) / sum(var(Ds))
end
function smse(D::Tuple, m::NoGradModels)
    return mse(D, m) / sum(var(D[1], dims=2))
end

# Root standardised MSE; forwards all arguments to the matching `smse` method.
rsmse(args...) = smse(args...) |> sqrt

In [4]:
"""
    mse(d::mocaputil.DataIterator, m::model.MTLDS_ng, z)

Weighted MSE of a multi-task LDS over the segments of `d`. Segment `ii` is predicted by the
LDS built with `model.make_lds(m, z[:, ii], m.η_h)`, so `z` needs one column per segment.
The errors are combined with `mocaputil.weights(d, as_pct=true)`.
`m.h` is zeroed at each new sequence and again at the end. Presumably the per-segment LDS
shares that state with `m` — confirm.
"""
function mse(d::mocaputil.DataIterator, m::model.MTLDS_ng, z::AbstractArray)
    @argcheck size(z, 2) == length(d)
    seg_errors = map(enumerate(d)) do (ii, segment)
        y, u, new_state = segment
        new_state && (m.h .= zeros(size(m, 1)))
        seg_lds = model.make_lds(m, z[:, ii], m.η_h)
        return mse(seg_lds(u) - y)
    end
    m.h .= zeros(size(m, 1))
    return dot(seg_errors, mocaputil.weights(d, as_pct=true))
end


"""
    mse(d::mocaputil.DataIterator, m::model.ORNN_ng, z, nn::Chain)

Weighted MSE of an RNN whose parameters are generated per segment. For segment `ii` of `d`,
`nn(z[:, ii])` is evaluated, untracked with `Tracker.data`, and passed to
`model.make_rnn_psi(m, ·, 1f0)` to build that segment's RNN. The errors are combined with
`mocaputil.weights(d, as_pct=true)`. `m.h` is zeroed at each new sequence and at the end.
"""
function mse(d::mocaputil.DataIterator, m::model.ORNN_ng, z::AbstractArray, nn::Chain)
    @argcheck size(z, 2) == length(d)
    seg_errors = map(enumerate(d)) do (ii, segment)
        y, u, new_state = segment
        new_state && (m.h .= zeros(size(m, 1)))
        ψ = Tracker.data(nn(z[:, ii]))
        seg_rnn = model.make_rnn_psi(m, ψ, 1f0)
        return mse(seg_rnn(u) - y)
    end
    m.h .= zeros(size(m, 1))
    return dot(seg_errors, mocaputil.weights(d, as_pct=true))
end

# Standardised variants of the latent-conditioned `mse` methods: divide by the summed
# per-feature target variance of `d`.
function smse(d::mocaputil.DataIterator, m::model.MTLDS_ng, z::AbstractArray)
    return mse(d, m, z) / sum(var(d))
end
function smse(d::mocaputil.DataIterator, m::model.ORNN_ng, z::AbstractArray, nn::Chain)
    return mse(d, m, z, nn) / sum(var(d))
end

In [5]:
# NOGRADMODELS FAILED ON RELOAD
# These re-state the `NoGradModels` methods for the current `model.ORNN_ng` type.
# NOTE(review): presumably re-including models.jl creates new type objects that the `const`
# union no longer covers — confirm.
function mse(Ds::Vector{D}, m::model.ORNN_ng) where {D <: Dict}
    return mse(mocaputil.DataIterator(Ds, 1000000), m)
end
function mse(D::Dict, m::model.ORNN_ng)
    prediction = m(D[:U])
    return mse(prediction - D[:Y])
end
function mse(V::Tuple, m::model.ORNN_ng)
    target, input = V[1], V[2]
    return mse(m(input) - target)
end
function mse(d::mocaputil.DataIterator, m::model.ORNN_ng)
    clear_state!() = (m.h .= zeros(size(m, 1)))
    seg_errors = map(d) do segment
        y, u, new_state = segment
        new_state && clear_state!()   # new sequence: forget the previous hidden state
        return mse(m(u) - y)
    end
    clear_state!()
    return dot(seg_errors, mocaputil.weights(d, as_pct=true))
end
function smse(d::mocaputil.DataIterator, m::model.ORNN_ng)
    return mse(d, m) / sum(var(d))
end
function smse(D::Dict, m::model.ORNN_ng)
    return mse(m(D[:U]) - D[:Y]) / sum(var(D))
end
function smse(Ds::Vector{D}, m::model.ORNN_ng) where {D <: Dict}
    return mse(mocaputil.DataIterator(Ds, 1000000), m) / sum(var(Ds))
end
function smse(D::Tuple, m::model.ORNN_ng)
    return mse(D, m) / sum(var(D[1], dims=2))
end

### Load in Data
See `2_Preprocess.ipynb` for how these data files are created.

**Note:** in the current on-disk state,
* `edin_Ys_30fps.bson` was created with `include_ftcontact=false, fps=30`,
* `edin_Xs_30fps.bson` was created with `include_ftcontact=true, include_ftmid=true, joint_pos=false, fps=fps, speed=false`.

The cells below load the `_final` versions of these files (`edin_Xs_30fps_final.bson`, `edin_Ys_30fps_final.bson`).

In [6]:
database = "../data/mocap/edin-style-transfer/"
files_edin = [joinpath(database, f) for f in readdir(database)];
# Style label = leading lowercase word of each file name (e.g. "angry_..." -> "angry").
# Matching the basename keeps this independent of how deeply `database` is nested.
style_name_edin = [match(r"^([a-z]+)_", basename(f))[1] for f in files_edin];
styles = unique(style_name_edin)
# For each style, the indices (into `files_edin`) of its files.
styles_lkp = [findall(==(s), style_name_edin) for s in styles];

In [7]:
# Load in data: raw per-file inputs (`:Xs`) and outputs (`:Ys`) written by preprocessing.
Usraw, Ysraw = BSON.load("edin_Xs_30fps_final.bson")[:Xs],
               BSON.load("edin_Ys_30fps_final.bson")[:Ys];

In [8]:
# Ysraw = [y[2:end,:] for y in Ysraw]
# Usraw = [hcat(u[2:end,1:end-8], u[1:end-1,end-7:end]) for u in Usraw];

In [9]:
# Standardise inputs and outputs. Scalers are fitted on the frames of all files stacked
# vertically (presumably dim 1 = frames — confirm against `mocaputil.MyStandardScaler`).
standardize_Y = fit(mocaputil.MyStandardScaler, reduce(vcat, Ysraw),  1)
standardize_U = fit(mocaputil.MyStandardScaler, reduce(vcat, Usraw),  1)

# Shift by one frame so the input at frame t is paired with the output at frame t+1.
Ys = [mocaputil.scale_transform(standardize_Y, y[2:end, :] ) for y in Ysraw];  # (1-step ahead of u)
Us = [mocaputil.scale_transform(standardize_U, u[1:end-1,:]) for u in Usraw];  # (1-step behind y)

# Sanity check on the first file: no input feature (columns 4 onwards) may be perfectly
# correlated with any output. This is an explicit check rather than `@assert`, which may be
# disabled.
let c = cor(Usraw[1][1:end-1, 4:end], Ysraw[1][2:end, :], dims=1)
    max_abs_cor = maximum(abs.(c[.!isnan.(c)]))
    if isapprox(max_abs_cor, 1.0)
        error("some input features perfectly correlated")
    end
end

# to invert: `mocaputil.invert(standardize_Y, y)`

In [10]:
# SENSE CHECK
# Check that U and Y were built without bugs (esp. that time steps align, so U can predict Y).
# Plots the input/output correlation matrix over all frames, titled with its largest
# absolute non-NaN entry.
let U_all = reduce(vcat, Us) |> f64,
    Y_all = reduce(vcat, Ys) |> f64
    c = cor(U_all, Y_all, dims=1)
    imshow(c, aspect="auto")
    finite_c = filter(!isnan, c)
    title(format("max (abs) corrcoeff: {:.8f}", maximum(abs, finite_c)))
    flush(stdout)
end
colorbar()

In [11]:
# Per-file offsets stored in smooth_offsets_per_file.bson.
file_offsets = let saved = BSON.load("smooth_offsets_per_file.bson")
    saved[:offsets]
end;

In [12]:
# Experiment container. It holds the raw outputs, the standardised outputs and inputs
# transposed so columns are frames, and the per-style file indices.
expmtdata = mocapio.ExperimentData(Ysraw, map(y -> Matrix(y'), Ys),
    map(u -> Matrix(u'), Us), styles_lkp);
# see ?mocapio.get_data

In [13]:
# Get training set for STL and pooled models.
style_ix = 1
train_ixs = setdiff(1:8, style_ix)
min_size = 63;    # presumably the minimum segment length kept by DataIterator — confirm
batch_size = 64;

trainPool, validPool, testPool = mocapio.get_data(expmtdata, style_ix, :split, :pooled)
# Use `batch_size` rather than repeating the literal so the two cannot drift apart.
trainIter = mocaputil.DataIterator(trainPool, batch_size, min_size=min_size);
trainIters = collect(trainIter);

In [14]:
# Test batches, segmented the same way as the training batches.
testIter = mocaputil.DataIterator(testPool, batch_size, min_size=min_size);
testIters = collect(testIter);

In [15]:
# Number of training segments per training style (STL split 0.875/0.125). The segmentation
# parameters are the shared `batch_size` / `min_size` instead of repeated literals.
segment_lens = [length(mocaputil.DataIterator(mocapio.get_data(expmtdata, i, :train, :stl, split=[0.875,0.125]),
            batch_size, min_size=min_size)) for i in train_ixs];
# Consecutive index ranges of each style's segments within the concatenation of all styles.
segment_lkp = [collect(i+1:j) for (i,j) in zip(vcat(0, cumsum(segment_lens[1:end-1])), cumsum(segment_lens))];
segment_names = ["angry", "childlike", "depressed", "neutral", "old", "proud", "sexy", "strutting"][train_ixs];
# Table of the first segment index of each style.
pretty.table(reshape(cumsum(vcat(1, segment_lens)[1:end-1]), :, 1), header_col=segment_names, dp=0)

In [16]:
# Insert a singleton axis at position `d` (cf. numpy's expand_dims); `reshape` shares memory with `x`.
function unsqueeze(x, d)
    dims = size(x)
    return reshape(x, (dims[1:d-1]..., 1, dims[d:end]...))
end

In [17]:
"""
    data_ahead(dataIters, start_ix, k_ahead)

Take the batches `start_ix+1 : start_ix+k_ahead` of `dataIters`. Concatenate their first
components column-wise, and separately their second components. Return both as a tuple.
"""
function data_ahead(dataIters, start_ix, k_ahead)
    window = start_ix+1:start_ix+k_ahead
    firsts = reduce(hcat, [dataIters[i][1] for i in window])
    seconds = reduce(hcat, [dataIters[i][2] for i in window])
    return firsts, seconds
end

In [18]:
# **** interacting directly with python codebase.
# **** In order to keep things simple, we'll just use the filesystem to pass data.
using PyCall
pysys = pyimport("sys")
pytorch = pyimport("torch")
# Put the sibling PyTorch repository's `src` directory first on Python's module search path.
let pysrc = normpath(joinpath(pwd(), "../human-motion-prediction-pytorch/src"))
    pushfirst!(PyVector(pysys."path"), pysrc)
end;
pymt = pyimport("forjulia")   #HACK: most modules in dir have underscores => PyCall.jl doesn't import these :(

In [19]:
base_args_ = Dict{String, Any}("seq_length_out"=>64, "decoder_size"=>1024, "batch_size"=>16,
            "latent_k"=>3,
            "human_size"=>64, 
            "input_size"=>44,
            "style_ix"=>style_ix,
            "use_cpu"=>true, 
            "data_dir"=>".");

# =========== LEAVE BELOW ALONE => FORMATTING FOR INPUT ARGS AND CMDLINE =================
# Convert `base_args_` into argv-style tokens for the python argparse parser:
#   * Bool `false` options are omitted (implicit),
#   * Bool `true` options become the bare flag `--key`,
#   * any other option becomes the two tokens `--key`, `string(value)`.
# Each value stays a single token, so a value containing spaces is not split apart.
# Empty tokens are dropped.
base_args = String[]
for (k, v) in base_args_
    if v isa Bool
        v && push!(base_args, "--" * k)
    else
        push!(base_args, "--" * k, string(v))
    end
end
filter!(!isempty, base_args);

In [20]:
# Route Python's stdout to Julia's stdout, once. The original stream is remembered in
# `orig_pystdout` so it can be restored.
if !@isdefined(orig_pystdout)
    orig_pystdout = pysys.stdout
    pysys.stdout = stdout
end;

In [21]:
# Python module `runbmark` (benchmark models), found via the sys.path entry added earlier.
pyb = PyCall.pyimport("runbmark")

##### Example: load a PyTorch benchmark model

In [22]:
style_ix = 4
# Benchmark GRU checkpoint for each style. The name presumably encodes style and training
# iteration — confirm.
_mname = ["open_1_3000", "open_2_18000", "open_3_3000", "open_4_13000", 
        "open_5_6000", "open_6_4000", "open_7_15000", "open_8_6000"][style_ix]
# NOTE(review): `--style_ix` is hard-coded to 7 although the checkpoint is chosen with
# `style_ix` (= 4) — confirm this mismatch is intended.
load_args = Any["--style_ix", string(7), "--load", "experiments/GRU/" * _mname, "--use_cpu"]
load_args = pyb.parseopts4bmark.parse_args(load_args)
load_args = pyb.parseopts4bmark.initial_arg_transform(load_args)
mtorch_b = pyb.create_model(load_args, [1,2], false)
mtorch_b.eval();
mtorch_b.num_layers = 1   # HACK: wasn't in first version.
mtorch_b.target_seq_len = 196

In [23]:
"""
    resample_output(Y, U; target_delta=0.03*2π)

Resample the columns (frames) of `Y` and `U` onto a uniform grid of *phase* instead of time.

The phase of frame `t` is `atan(U[end-1, t], U[end, t])`. The last two rows of `U`
presumably hold a sine/cosine phase encoding — confirm. The advance between consecutive
frames is the negated phase difference wrapped into `[0, 2π)` by `mod2pi`. Its running sum,
starting at 0, gives the phase clock of the input frames.

Output frame `k` (`k = 1:N`, with `N = floor(total phase / target_delta)`) linearly
interpolates the two input frames that bracket phase `(k-1) * target_delta`. That phase is
accumulated by repeated addition. Returns `(Yrsmp, Ursmp)`, each with `N` columns.
"""
function resample_output(Y, U; target_delta=0.03*2π)
    nrow_y, _ = size(Y)
    phase = atan.(U[end-1, :], U[end, :])
    clock = vcat(0, cumsum(mod2pi.(-diff(phase))))   # phase clock of every input frame
    n_out = floor(Int, clock[end] / target_delta)
    # requested phases 0, δ, δ+δ, ... (only the first n_out entries are used)
    targets = vcat(0, cumsum(fill(target_delta, n_out)))

    Yrsmp = fill(one(eltype(Y)) * NaN, nrow_y, n_out)
    Ursmp = fill(one(eltype(U)) * NaN, size(U, 1), n_out)

    lo = 1   # left bracketing input frame; only moves forward because targets increase
    for k in 1:n_out
        ϕ = targets[k]
        while ϕ > clock[lo + 1]
            lo += 1
        end
        if ϕ == 0 && lo == 1
            # the very first target sits exactly on the first input frame: copy it
            Yrsmp[:, k] = Y[:, 1]
            Ursmp[:, k] = U[:, 1]
            continue
        end
        w = (ϕ - clock[lo]) / (clock[lo + 1] - clock[lo])
        @views Yrsmp[:, k] .= Y[:, lo] .* (1 - w) .+ Y[:, lo + 1] .* w
        @views Ursmp[:, k] .= U[:, lo] .* (1 - w) .+ U[:, lo + 1] .* w
    end
    return Yrsmp, Ursmp
end

In [319]:
# Number of training batches for style 1 under the 0.8/0.0/0.2 STL split.
let pool = mocapio.get_data(expmtdata, 1, :train, :stl, split=[0.8,0.0,0.2])
    length(mocaputil.DataIterator(pool, 64, min_size=min_size))
end

In [318]:
trainIters[1] |> length

In [322]:
# Per-style training / validation batches built from phase-resampled sequences.
# `trainIters[s]` and `validIters[s]` hold the collected batches of style `s`.
trainIters = []
validIters = []
for style in 1:8
    style_train = mocapio.get_data(expmtdata, style, :train, :stl, split=[0.8,0.0,0.2])
    style_valid = mocapio.get_data(expmtdata, style, :test, :stl, split=[0.8,0.0,0.2])
    # NOTE(review): the resampled matrices are written back into the Dicts from `get_data`.
    # If those Dicts are shared with `expmtdata` rather than copied, re-running this cell
    # resamples already-resampled data — confirm.
    for x in style_train
        x[:Y], x[:U] = resample_output(x[:Y], x[:U])
    end
    for x in style_valid
        x[:Y], x[:U] = resample_output(x[:Y], x[:U])
    end
    # These names must not already exist as globals: under the notebook's soft global scope,
    # assigning to an existing global (e.g. `trainIter`) here would overwrite it.
    style_trainIter = mocaputil.DataIterator(style_train, 64, min_size=min_size);
    style_validIter = mocaputil.DataIterator(style_valid, 64, min_size=min_size);
    push!(trainIters, collect(style_trainIter))
    push!(validIters, collect(style_validIter))
end

In [317]:
# Keep only full-length batches (64 columns in the first component) for every style.
let is_full = batch -> size(batch[1], 2) == 64
    for s in 1:8
        trainIters[s] = filter(is_full, trainIters[s])
        validIters[s] = filter(is_full, validIters[s])
    end
end

## Model

In [26]:
# Style classifier: a GRU over 67-dimensional frames, then an MLP head giving 8 logits.
gru = GRU(67, 512)
fc1 = Chain(Dense(512, 300, relu), Dense(300, 8, identity))
fill!(gru.init.data, 0)   # zero initial hidden state

"""
    _forwardlogit(Y, gru, fc1)

Feed the slices `Y[:, t, :]` (for `t` along dim 2) through `gru` one step at a time, then map
the final recurrent state through `fc1` and return the logits.
The recurrent state is not reset here; callers reset it (`Flux.reset!`) beforehand.
"""
function _forwardlogit(Y::AbstractArray, gru, fc1)
    foreach(t -> gru(Y[:, t, :]), 1:size(Y, 2))
    return fc1(gru.state)
end

"""
    forward(Y, gru, fc1)

Class probabilities: `softmax` of the logits produced by `_forwardlogit(Y, gru, fc1)`.
"""
forward(Y::AbstractArray, gru, fc1) = softmax(_forwardlogit(Y, gru, fc1))

In [28]:
# Train the style classifier. Each iteration:
#   * draws one random batch per style (label k for style k),
#   * takes one ADAM step on the logit cross-entropy,
#   * evaluates the same loss on one random validation batch per style.
n_iters = 400
opt = ADAM(1e-4)
ps = Flux.params(gru, fc1)
history = fill(NaN, n_iters, 2)   # columns: training loss, validation loss

# Untracked copies of the models for evaluation (no gradient tape).
# NOTE(review): `Tracker.data` presumably returns the underlying arrays, so these share
# weights with `gru`/`fc1` and see every update — confirm.
test_gru, test_fc1 = map(x->Flux.mapleaves(Tracker.data, x), (gru, fc1))

target = Flux.onehotbatch(1:8, 1:8)   # column k = one-hot label of style k (loop-invariant)

for i in 1:n_iters
    Flux.reset!(gru)

    # one random batch per style, stacked along a new 3rd (minibatch) dimension
    cY = map(1:8) do style
        unsqueeze(trainIters[style][rand(1:length(trainIters[style]))][1], 3)
    end |> x -> cat(x..., dims=3);

    logittarget_hat = _forwardlogit(cY, gru, fc1)
    loss = Flux.logitcrossentropy(logittarget_hat, target)

    Tracker.back!(loss)
    for p in ps
        Flux.Tracker.update!(opt, p, p.grad)
    end

    # Validation
    Flux.reset!(test_gru)
    validY = map(1:8) do style
        unsqueeze(validIters[style][rand(1:length(validIters[style]))][1], 3)
    end |> x -> cat(x..., dims=3);
    valid_logittarget_hat = _forwardlogit(validY, test_gru, test_fc1)
    valid_loss = Flux.logitcrossentropy(valid_logittarget_hat, target)

    # `Tracker.data` handles both the tracked training loss and the validation loss. The
    # latter comes from the untracked models and is presumably a plain number, on which
    # `.data` would fail.
    train_val = Tracker.data(loss)
    valid_val = Tracker.data(valid_loss)
    if i % 5 == 0
        println(format("Loss at {:03d} is {:01.4f} ({:01.4f})", i, train_val, valid_val))
    end
    history[i, 1] = train_val
    history[i, 2] = valid_val
end

In [70]:
# Cross-entropy of the prediction against the one-hot label of style 3. No transposes: with
# transposed arguments, `size(y, 2)` becomes 8 and the loss is wrongly divided by 8.
Flux.crossentropy(_pred, Flux.onehot(3, 1:8))

In [93]:
# Classify one training batch (style 4, batch 72), starting from a freshly reset GRU state.
let batch = trainIters[4][72][1]
    Flux.reset!(gru)
    forward(batch, gru, fc1)
end

In [76]:
# `forward(Y, gru, fc1)` returns only the softmax class probabilities, so the cross-entropy
# against the one-hot label of style 3 is computed separately. The state is reset first so
# the previous sequence does not leak in.
Flux.reset!(gru)
_pred = forward(trainIters[3][72][1], gru, fc1)
_mce = Flux.crossentropy(_pred, Flux.onehot(3, 1:8))

To do:
1. Verify optimisation is working.
2. Hold out a validation set during training.
3. Experiment on style transfer data.

## Use model outputs

In [30]:
# style_ix = 4
# _mname = ["open_1_3000", "open_2_18000", "open_3_3000", "open_4_13000", 
#         "open_5_6000", "open_6_4000", "open_7_15000", "open_8_6000"][style_ix]
# load_args = []
# push!(load_args, "--style_ix")
# push!(load_args, string(7))
# push!(load_args, "--load")
# push!(load_args, "experiments/GRU/" * _mname)
# push!(load_args, "--use_cpu")
# load_args = pyb.parseopts4bmark.parse_args(load_args)
# load_args = pyb.parseopts4bmark.initial_arg_transform(load_args)
# mtorch_b = pyb.create_model(load_args, [1,2], false)
# mtorch_b.eval();
# mtorch_b.num_layers = 1   # HACK: wasn't in first version.
# mtorch_b.target_seq_len = 196

In [642]:
# Checkpoint of the multi-task model to evaluate. Other checkpoints that were tried:
#   experiments/bottleneck_16_1_lowlr_10000
#   experiments/dynamicsdict_256_20000            (with "--dynamicsdict")
#   experiments/final/k3_bottleneck24_2_prec9_mtrnn_lowlr20000
#   experiments/final/k3_bottleneck24_{style_ix}_prec9_mtrnn_lowlr20000
#   experiments/biasonly/style{style_ix}_128_8e-4_k3_20000
#   experiments/nobias/style9_k8_30000            (<= ***)
#   experiments/final/k3_bottleneck24_{style_ix}_mtrnn_lowlr20000
#   experiments/final/2d_embedding_model20000
#   experiments/final/k3_bottleneck16_1_mtrnn_lowlr20000
load_args = push!(copy(base_args), "--load", "experiments/biasonly/style9_128_12e-4_k8_40000")  # <= ***
load_args = pymt.parseopts.parse_args(load_args)
load_args = pymt.parseopts.initial_arg_transform(load_args)
# NOTE(review): the meaning of 850 is not visible here — see `learn_mtfixbmodel.create_model`.
mtorch = pymt.learn_mtfixbmodel.create_model(load_args, 850)
mtorch.eval();

In [242]:
# All pooled (`:all`) batches for `style_ix`, materialised as a vector.
allPools = let pooled = mocapio.get_data(expmtdata, style_ix, :all, :pooled)
    collect(mocaputil.DataIterator(pooled, 64, min_size=min_size))
end;

In [233]:
# Split `x` into consecutive, non-overlapping length-64 pieces; a trailing remainder is dropped.
function chunk64(x)
    nchunks = length(x) ÷ 64
    return [x[(k - 1) * 64 + 1 : k * 64] for k in 1:nchunks]
end

In [278]:
# BSON.bson("exemplar_ixs.bson", ixs=[[1,55,72,88], [115,147,204,230], [247,269,287,318], [359,376,411,431], 
#         [487,503,540,573],[601,617,644,661], [691,716,778,823], [848,867,889,930]])

In [281]:
# Exemplar batch indices for each style, as saved in exemplar_ixs.bson.
exemplar_ixs = let saved = BSON.load("exemplar_ixs.bson")
    saved[:ixs]
end

In [257]:
# Index range of each style's batches when the styles' batches are laid out consecutively.
ixs_by_style = let starts = cumsum(vcat(1, ls[1:end-1])), stops = cumsum(ls)
    [a:b for (a, b) in zip(starts, stops)]
end

In [256]:
# Hard-coded number of batches per style (also recomputed from the data further down).
ls = Int[114, 132, 112, 128, 114, 90, 157, 117];

In [283]:
# Number of (unfiltered) batches per style over all data. Presumably these serve as offsets
# into `trainItersOrig`, which is also built without filtering — confirm.
# An alternative that counts only full-length (64-frame) batches was computed here and then
# immediately overwritten; it is kept for reference only:
# ls = [length(filter(x->size(x[1], 2) == 64, collect(mocaputil.DataIterator(mocapio.get_data(expmtdata, style, 
#                         :all, :stl), 64, min_size=63)))) for style in 1:8]
ls = [length(collect(mocaputil.DataIterator(mocapio.get_data(expmtdata, style, 
                        :all, :stl), 64, min_size=63))) for style in 1:8]

In [284]:
reduce(+, ls)   # total number of batches across all styles

In [171]:
# One random batch index per style, drawn uniformly from that style's index range.
let starts = cumsum(vcat(1, ls[1:end-1])), stops = cumsum(ls)
    map((a, b) -> rand(a:b), starts, stops)
end

In [324]:
# Every pooled (`:all`) batch, in the order produced by `get_data`, as a plain vector.
trainItersOrig = let pooled = mocapio.get_data(expmtdata, 1, :all, :pooled)
    collect(mocaputil.DataIterator(pooled, 64, min_size=min_size))
end;

In [488]:
vcat(0, ls) .+ 3   # the per-style counts (with a leading 0), each shifted by 3

In [621]:
# Random search over latent codes, currently restricted to target style 1 (`j in 1:1`).
# Each of 20 rounds:
#   * draws one candidate latent index per style (`_Zixs`),
#   * runs the multi-task model on one batch of every source style, conditioned on the
#     candidate of style 1,
#   * classifies the phase-resampled prediction with the style classifier.
# Results:
#   scores[n] - mean over the 8 source batches of `diag(out)`. Only entry 1 (the probability
#               of class 1) is a number; entries 2:8 stay NaN because only column 1 of `out`
#               is filled.
#   Zsn[n]    - the candidate indices drawn in round n.
Zsn = []
scores = []
@showprogress for n = 1:20
    yhats = []
    score = []
    # one candidate per style from that style's index range (`j-1` excludes the range's last index)
    _Zixs = [rand(i:j-1) for (i,j) in zip(cumsum(vcat(1, ls[1:end-1])), cumsum(ls))] 
    for style_ix = 1:8
    #     _Zixs = [20, 188, 312, 424, 553, 670, 733, 880]
        # a single batch per source style. Presumably this is that style's 3rd batch, assuming
        # `trainItersOrig` is ordered by style with `ls[s]` batches each — confirm.
        for _batch_num in [3 + cumsum(vcat(0, ls))[style_ix]]   #exemplar_ixs[style_ix]
            _Yb, _Ub = trainItersOrig[_batch_num]
            yhat = []
            
            @pywith pytorch.no_grad() begin
                for j in 1:1
                    # `get` on a Python object indexes with 0-based Python indices, hence `-1`.
                    # Z_mu presumably holds the latent means — confirm.
                    μprop = get(mtorch.mt_net.Z_mu, _Zixs[j] -1)
                    # inputs are passed as a (1, T, 35) batch. NOTE(review): 35 is hard-coded,
                    # presumably size(_Ub, 1). The third argument is a 1×8 array of 1e-6,
                    # presumably a near-zero latent spread — confirm. `[1]` takes the first
                    # returned item; `[1,:,:]` takes its first batch element.
                    out = mtorch.forward(pytorch.tensor(reshape(_Ub', 1, :, 35)), 
                            pytorch.tensor(μprop), 
                            pytorch.tensor(repeat([1f-6],1,8)))[1].numpy()[1,:,:]
                    # resample the prediction onto the phase grid given by `_Ub`'s phase rows
                    yhat_rsmp, _ = resample_output(out', _Ub)
                    push!(yhat, yhat_rsmp)
                end
            end
            push!(yhats, yhat)
            
            out = zeros(8,8)*NaN
            for j in 1:1
                Flux.reset!(test_gru)
#                 display(yhat[j])
                # column 1 only: classifier probabilities for the j = 1 prediction
                out[:,1] = forward(yhat[j], test_gru, test_fc1)[:]
            end
            # diag(out)[1] is the probability of class 1; the remaining entries are NaN
            push!(score, diag(out))
        end
    end
    push!(scores, mean(score))
    push!(Zsn, _Zixs)
end

In [622]:
# First latent index of the search round whose first score entry is largest.
let first_scores = reduce(hcat, scores)[1, :]
    Zsn[argmax(first_scores)][1]
end

In [506]:
# For each style i: take the round with the largest score i, and that round's candidate
# index for style i.
_Zs_mtbias2 = let Z = reduce(hcat, Zsn)', S = reduce(hcat, scores)'
    best_rounds = mapslices(argmax, S, dims=1)[:]
    [Z[r, i] for (i, r) in enumerate(best_rounds)]
end

In [507]:
join(_Zs_mtbias2, ", ")   # comma-separated, ready to paste as a literal

In [353]:
# Index of the best search round for each style's score.
vec(mapslices(argmax, permutedims(reduce(hcat, scores)), dims=1))

In [629]:
# Chosen latent indices (one per target style), presumably from the latent searches above.
# Superseded choices are kept as comments; only the active assignments take effect.
# _Zs_bias_only = [75, 209, 274, 408, 535, 619, 798, 947];
# _Zs_bias_only = [94, 182, 288, 475, 568, 617, 794, 851];
_Zs_bias_only = [10, 186, 282, 376, 517, 675, 822, 876];
_Zs_bias_only_avg4 = [11, 186, 335, 425, 590, 617, 763, 963]
_Zs_mtbias = [50, 119, 293, 427, 554, 668, 825, 856];
# _Zs_mtbias = [64, 188, 248, 479, 555, 618, 746, 857];
# _Zs_mtbias = [5, 189, 322, 445, 525, 649, 701, 857];
_Zs_mtbias_avg4 = [48, 118, 341, 405, 553, 684, 762, 855];
_Zs_orig = [20, 188, 312, 424, 553, 670, 733, 880];

In [643]:
# Style-transfer evaluation. For every source style and each of its selected batches, the
# batch is run through the multi-task model once per target style j, conditioned on that
# style's chosen latent `_Zixs[j]`. Each prediction is then classified.
#   outs[s][b][i, j] = classifier probability of class i for batch b of source style s,
#                      generated with the latent of target style j.
outs = []
_Zixs = _Zs_bias_only_avg4
avg4 = true   # true
@showprogress for style_ix = 1:8
    outs_batch = []
    # either the exemplar batches of this style, or a single batch at offset 3 of the style
    batches = avg4 ? exemplar_ixs[style_ix] : [3 + cumsum(vcat(0, ls))[style_ix]]
    for _batch_num in batches

        _Yb, _Ub = trainItersOrig[_batch_num]
        yhat = []
        @pywith pytorch.no_grad() begin
            for j in 1:8
                # 0-based Python lookup of target style j's latent (Z_mu presumably = means)
                μprop = get(mtorch.mt_net.Z_mu, _Zixs[j] -1)
                # NOTE(review): input dim 35 and the 1×8 array of 1e-6 are hard-coded; the
                # latter is presumably a near-zero latent spread — confirm.
                out = mtorch.forward(pytorch.tensor(reshape(_Ub', 1, :, 35)), 
                        pytorch.tensor(μprop), 
                        pytorch.tensor(repeat([1f-6],1,8)))[1].numpy()[1,:,:]
                yhat_rsmp, _ = resample_output(out', _Ub)
                push!(yhat, yhat_rsmp)
            end
        end

        out = zeros(8,8)*NaN
        for j in 1:8
            Flux.reset!(test_gru)
            # column j: class probabilities for the prediction made with target style j
            out[:,j] = forward(yhat[j], test_gru, test_fc1)[:]
        end
        push!(outs_batch, out)
    end
    push!(outs, outs_batch)
end

In [584]:
fig, axs = subplots(2,4, figsize=(10,5))

# Fill the 2×4 grid in reading order (row by row). `axs[:]` is column-major, which would put
# style 2 below style 1 instead of beside it. This assumes PyCall returns `axs` as a 2×4
# Julia matrix of axes objects.
panel_axes = vec(permutedims(axs))
for style_ix = 1:8
    out = mean(outs[style_ix])  # mean over this style's batches
    panel_axes[style_ix].imshow(out); #colorbar(panel_axes[style_ix])
    panel_axes[style_ix].set_title(style_ix)
end

In [368]:
# Short axis labels for styles 1:8, in style-index order.
styles = ["angry", "child", "depr.", "neut.", "old", "proud", "sexy", "strut."];

# Format a number with exactly two decimals (for cell annotations).
function twodp(x)
    return format("{:.2f}", x)
end

In [645]:
ax = gca()

diag_color = ColorMap("Reds")(0.3)
cmap = ColorMap("Greys_r")

# Per-style score matrices averaged over batches, computed once and reused below.
mean_outs = [mean(outs[style_ix]) for style_ix in 1:8]

# Column s = diagonal of style s's matrix (row j = target j), mapped to RGBA by `cmap`.
_payload = reduce(hcat, [diag(M) for M in mean_outs])
_payload = cmap(_payload)
# Paint the main diagonal (source == target) in `diag_color`. The linear indices of the
# diagonal in the first colour plane are shifted by whole planes to reach planes 2 and 3.
# The plane size is derived from the array instead of hard-coding 64.
plane = size(_payload, 1) * size(_payload, 2)
diagixs = diagind(_payload[:,:,1])
for ch in 1:3
    _payload[diagixs .+ (ch - 1) * plane] .= diag_color[ch]
end

_im = ax.imshow(reverse(_payload, dims=1))

plt.setp(ax, yticks=0:7, yticklabels=reverse(styles))
plt.setp(ax, xticks=0:7, xticklabels=styles)

ax.set_xlabel("source")
ax.set_ylabel("target")
# plt.colorbar(_im,fraction=0.022, pad=0.04)

# Annotate every cell; rows are flipped to match `reverse(_payload, dims=1)`.
for style_ix in 1:8, j in 1:8
    val = mean_outs[style_ix][j, j]
    ax.text(style_ix-1, 8-j, twodp(val), ha="center", va="center", 
        color=ColorMap("Greens")(σ(20*(val-0.5) +1)), fontsize=9)
end
# savefig("img/pub/style_tfer_matrix_avg4.pdf")
# savefig("img/pub/style_tfer_matrix_biasonly_avg4.pdf")

In [437]:
ax = gca()

# Batch-averaged score matrix of each source style, computed once.
mean_outs = [mean(outs[style_ix]) for style_ix in 1:8]
_im = ax.imshow(reduce(hcat, [diag(M) for M in mean_outs]), cmap=ColorMap("Greys_r"))

plt.setp(ax, yticks=0:7, yticklabels=styles)
plt.setp(ax, xticks=0:7, xticklabels=styles)

ax.set_xlabel("target")
ax.set_ylabel("source")
# plt.colorbar(_im,fraction=0.022, pad=0.04)

for style_ix in 1:8, j in 1:8
    val = mean_outs[style_ix][j, j]
    ax.text(style_ix-1, j-1, twodp(val), ha="center", va="center", 
        color=ColorMap("Greens")(σ(20*(val-0.5) -0.4)), fontsize=9)
end
# savefig("img/pub/style_tfer_matrix.pdf")

In [397]:
# Batch-averaged score matrix of style 8, rounded to 2 d.p. Only this style's mean is computed.
round.(mean(outs[8]), digits=2)

In [647]:
ax = gca()
# Batch-averaged matrix per source style, then the element-wise average over source styles.
mean_outs = [mean(outs[i]) for i in 1:8]
avg_confusion = mean(cat(mean_outs..., dims=3), dims=3)[:,:,1]
_im = ax.imshow(reverse(avg_confusion, dims=1), cmap=ColorMap("Greys_r"))

plt.setp(ax, yticks=0:7, yticklabels=reverse(styles))
plt.setp(ax, xticks=0:7, xticklabels=styles)

ax.set_xlabel("target")
ax.set_ylabel("predicted")
# plt.colorbar(_im,fraction=0.022, pad=0.04)

# Annotate every cell; rows are flipped to match the reversed image.
for i in 1:8, j in 1:8
    val = mean([mean_outs[s][j, i] for s in 1:8])
    ax.text(i-1, 8-j, twodp(val), ha="center", va="center", 
        color=ColorMap("Greens")(σ(20*(val-0.5) +1)), fontsize=9)
end
# savefig("img/pub/style_tfer_avg_confusion_biasonly_avg4.pdf")

In [164]:
ax = gca()
# `outs[s]` is a vector of per-batch score matrices, so average over batches before taking
# the diagonal or indexing [j, j].
mean_outs = [mean(outs[style_ix]) for style_ix in 1:8]
_im = ax.imshow(reduce(hcat, [diag(M) for M in mean_outs]), cmap=ColorMap("Greys_r"))
# ax.set_xticklabels(0:8)
# ax.set_yticklabels(0:8)

plt.setp(ax, yticks=0:7, yticklabels=styles)
plt.setp(ax, xticks=0:7, xticklabels=styles)

ax.set_xlabel("target")
ax.set_ylabel("source")
# plt.colorbar(_im,fraction=0.022, pad=0.04)

for style_ix in 1:8, j in 1:8
    val = mean_outs[style_ix][j, j]
    ax.text(style_ix-1, j-1, twodp(val), ha="center", va="center", 
        color=ColorMap("Greens")(σ(20*(val-0.5) +1)), fontsize=6)
end
savefig("img/pub/style_tfer_matrix_biasonly.pdf")

In [196]:
ax = gca()
# `outs[s]` is a vector of per-batch score matrices, so average over batches before taking
# the diagonal or indexing [j, j].
mean_outs = [mean(outs[style_ix]) for style_ix in 1:8]
_im = ax.imshow(reduce(hcat, [diag(M) for M in mean_outs]), cmap=ColorMap("Greys_r"))
# ax.set_xticklabels(0:8)
# ax.set_yticklabels(0:8)

plt.setp(ax, yticks=0:7, yticklabels=styles)
plt.setp(ax, xticks=0:7, xticklabels=styles)

ax.set_xlabel("target")
ax.set_ylabel("source")
# plt.colorbar(_im,fraction=0.022, pad=0.04)

for style_ix in 1:8, j in 1:8
    val = mean_outs[style_ix][j, j]
    ax.text(style_ix-1, j-1, twodp(val), ha="center", va="center", 
        color=ColorMap("Greens")(σ(20*(val-0.5) +1)), fontsize=6)
end
savefig("img/pub/style_tfer_matrix_biasonly2.pdf")

In [454]:
ax = gca()
# Batch-averaged matrix per source style, then the element-wise average over source styles.
mean_outs = [mean(outs[i]) for i in 1:8]
avg_confusion = mean(cat(mean_outs..., dims=3), dims=3)[:,:,1]
_im = ax.imshow(avg_confusion, cmap=ColorMap("Greys_r"))

plt.setp(ax, yticks=0:7, yticklabels=styles)
plt.setp(ax, xticks=0:7, xticklabels=styles)

ax.set_xlabel("target")
ax.set_ylabel("predicted")
# plt.colorbar(_im,fraction=0.022, pad=0.04)

for i in 1:8, j in 1:8
    val = mean([mean_outs[s][j, i] for s in 1:8])
    ax.text(i-1, j-1, twodp(val), ha="center", va="center", 
        color=ColorMap("Greens")(σ(20*(val-0.5) +1)), fontsize=9)
end
savefig("img/pub/style_tfer_avg_confusion_biasonly2_avg4.pdf")