In [1]:
using Pkg: @pkg_str
pkg"activate ."

In [2]:
using MLDatasets
using MLDataUtils
using Statistics
using Flux

┌ Info: Recompiling stale cache file /Users/oxinabox/.julia/compiled/v1.1/MLDatasets/9CUQK.ji for MLDatasets [eb30cadb-4394-5ae3-aed4-317e484a6458]
└ @ Base loading.jl:1184
┌ Info: Recompiling stale cache file /Users/oxinabox/.julia/compiled/v1.1/MLDataUtils/CQWB9.ji for MLDataUtils [cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d]
└ @ Base loading.jl:1184
┌ Info: Recompiling stale cache file /Users/oxinabox/.julia/compiled/v1.1/Flux/QdkVy.ji for Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
└ @ Base loading.jl:1184


loaded


# FashionMNIST
https://github.com/zalandoresearch/fashion-mnist

![](https://raw.githubusercontent.com/zalandoresearch/fashion-mnist/master/doc/img/fashion-mnist-sprite.png)

| Label | Description |
| --- | --- |
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |


## As a programmer, I clearly just randomly grab two items of clothing out of my wardrobe without concern for if they would leaving me appropriately covered.

I would like a ML system that looks at images of my two items of clothing,
and tells me if they are going to cover me.
That means a dress, or a top and trousers.

In [3]:
function appropriately_dressed(clothing_items...)
    3 ∈ clothing_items && return true  # A dress is both top and bottoms
    
    # need trousers and a top of some kind
    return 1 ∈ clothing_items && length(intersect(clothing_items,(0, 2, 4, 6))) > 0 
end

appropriately_dressed (generic function with 1 method)

## Generate a dataset
We are going to generate a dataset,
by taking two copies of FashionMNIST,
shuffling them, and drawing them off in pairs,
labelling each with a suitable label.

In [4]:
function generate_pants_dataset(data, max_obs=Inf)
    combination_images = Vector{Vector{Float32}}()
    is_dressed_labels = Vector{Bool}()
    for (img1, lbl1) in eachobs(shuffleobs(data)), (img2, lbl2) in eachobs(shuffleobs(data))
        push!(combination_images, [img1[:]; img2[:]])
        push!(is_dressed_labels,  appropriately_dressed(lbl1, lbl2))
        if length(is_dressed_labels) >= max_obs
            break
        end
    end
    return combination_images, is_dressed_labels
end

generate_pants_dataset (generic function with 2 methods)

In [5]:
using Random
#Random.seed!(4)
Random.seed!(14)


train_data = generate_pants_dataset((FashionMNIST.traintensor(), FashionMNIST.trainlabels()), 20_000)
@show mean(last, eachobs(train_data))

test_data = generate_pants_dataset((FashionMNIST.testtensor(), FashionMNIST.testlabels()), 1_000);
@show mean(last, eachobs(test_data))

mean(last, eachobs(train_data)) = 0.49405
mean(last, eachobs(test_data)) = 0.489


0.489

# Construct a Balanced Training Set

In [6]:
balanced_train_data = undersample(last, train_data);

@show nobs(balanced_train_data)
@show mean(last, eachobs(balanced_train_data));

nobs(balanced_train_data) = 19762
mean(last, eachobs(balanced_train_data)) = 0.5


# Flux

In [39]:
function flux_model(in_size = 1568)
    leaky_relu6(x) = 0.01x + clamp(x, 0, 6)

    return Chain(
        Dense(in_size, 512, leaky_relu6), 
        Dense(512, 128, leaky_relu6),
        Dense(128, 64, leaky_relu6), 
        Dense(64, 1, σ),
        vec
    )
    
end

function demo_flux(train_data, test_data)
    mdl = flux_model()
    
    features = reduce(hcat, first(train_data))
    labels = Float32.(last(train_data))
    
    Flux.train!(
        params(mdl),
        Iterators.repeated((features, labels), 50), # 50 epochs
        Flux.ADAM()
    ) do xs, ys # This block repressents 1 epoch
        mean(Flux.binarycrossentropy.(ys, mdl(xs)))
    end
    
    test_features, test_labels = test_data
    
    probs = mapreduce(Flux.data∘mdl, vcat, test_features)
    classes = probs .> 0
    @show mean(classes)
    acc = mean(test_labels .== classes)
    @show acc
end

demo_flux (generic function with 1 method)

In [40]:
@time demo_flux(train_data, test_data)

ErrorException: Loss is NaN

In [None]:
@time demo_flux(balanced_train_data, test_data)

In [13]:
function trail_1(train_data)
    mdl = flux_model()
    features = first(train_data)
    map(mdl, features)
end

@time trail_1(train_data)
@time trail_1(train_data);

  1.610899 seconds (2.40 M allocations: 308.420 MiB, 12.70% gc time)
  1.578001 seconds (2.40 M allocations: 308.341 MiB)


In [34]:

function flux_model2(in_size = 1568)
    leaky_relu6(x) = 0.01x + clamp(x, 0, 6)

    return Chain(
        Dense(in_size, 512, leaky_relu6), 
        Dense(512, 128, leaky_relu6),
        Dense(128, 64, leaky_relu6), 
        Dense(64, 1, σ),
        vec
    )
end    

function trail_2(train_data)
    mdl = flux_model2()
    features = reduce(hcat, first(train_data))
    mdl(features)
end

@time trail_2(train_data)
@time trail_2(train_data)

  1.357463 seconds (819.18 k allocations: 390.562 MiB, 43.48% gc time)
  0.632675 seconds (239 allocations: 348.017 MiB, 64.92% gc time)


Tracked 20000-element Array{Float32,1}:
 0.4894056f0 
 0.50106543f0
 0.49604225f0
 0.4906915f0 
 0.50077826f0
 0.4745427f0 
 0.4569315f0 
 0.47123772f0
 0.445683f0  
 0.48552188f0
 0.47005388f0
 0.49978563f0
 0.44844544f0
 ⋮           
 0.43247086f0
 0.47136745f0
 0.5098584f0 
 0.50043267f0
 0.48318735f0
 0.48956248f0
 0.48548692f0
 0.44235912f0
 0.48430443f0
 0.51597124f0
 0.45009404f0
 0.4508467f0 

In [19]:
features = reduce(hcat, first(train_data))'

20000×1568 LinearAlgebra.Adjoint{Float32,Array{Float32,2}}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0        0.0        0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0        0.0        0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0313726  0.0        0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.705882   0.415686   0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0 