---
jupyter: julia-1.10
---






# Classifying wings

![](wings/Asilidae/Asilidae%2011.png)

In this quick lesson, we try to classify insect wings using the tools from the previous lessons.


In [None]:
using MetricSpaces
using Images
using DataFrames
using CairoMakie
using Ripserer, PersistenceDiagrams
import Plots
using ProgressMeter

using Clustering
import StatsPlots

import StatisticalMeasures.ConfusionMatrices as CM
using MultivariateStats
using Chain
using ImageFiltering

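We build a data frame with the class, file name, and full path of each image. The column names are in Portuguese: `Classe` (class), `Caminho` (path), and `Caminho_completo` (full path).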


In [None]:
ds = DataFrame()

for (root, dir, files) in walkdir("wings/")

    for file in files
        dc = Dict(:Classe => root |> basename, :Caminho => file, :Caminho_completo => joinpath(root, file))
        
        push!(ds, dc, cols = :union)
    end
end

ds;

In [None]:
ds_split = groupby(ds, :Classe) |> collect;

In [None]:
function plot_mosaic(s)
    mosaicview(
    [imresize(load(f), (150, 300)) for f ∈ s.Caminho_completo[1:min(end, 21)]]
    , ncol = 3
    ,fillvalue = RGB24(1)
    )
end;

f = ds.Caminho_completo[1]

## The dataset

The dataset consists of several wing images from three different families of flies (order Diptera):


### Asilidae


In [None]:
plot_mosaic(ds_split[1])

### Ceratopogonidae


In [None]:
plot_mosaic(ds_split[2])

### Tipulidae


In [None]:
plot_mosaic(ds_split[3])

We load all images as grayscale matrices:


In [None]:
images = [load(img) .|> Gray |> channelview for img ∈ ds.Caminho_completo];

A heatmap confirms that the first image was loaded correctly:


In [None]:
heatmap(images[1])

## Matrix to $\mathbb{R}^2$

As before, we need to transform each image into a set of points in the plane.


In [None]:
function img_to_points(img)
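    # Smooth the image with a Gaussian kernel (σ = 1) before thresholding.
    img2 = imfilter(img, Kernel.gaussian(1)) .|> float
    # Keep the pixels with intensity at most 0.8, i.e. the darker ones.
    ids = findall(x -> x <= 0.8, img2)
    # n×2 matrix holding the (row, column) index of each kept pixel.
    pts = getindex.(ids, [1 2])

    [ [ p[1], p[2] ] for p in eachrow(pts)] |> EuclideanSpace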
end;
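
To make the thresholding step concrete, here is a minimal sketch on a tiny hand-made matrix. It skips the Gaussian blur and the `EuclideanSpace` wrapper, and the helper name `dark_pixel_coordinates` is hypothetical; it only illustrates how `findall` turns dark pixels into (row, column) coordinates.

```julia
# Toy "image": 1.0 is white background, small values are dark pixels.
toy = [1.0 1.0 1.0;
       1.0 0.2 0.1;
       1.0 1.0 0.3]

# Hypothetical helper (not part of the lesson code): keep the coordinates
# of every pixel at or below the threshold, as plain [row, column] vectors.
function dark_pixel_coordinates(img; threshold = 0.8)
    ids = findall(x -> x <= threshold, img)   # one CartesianIndex per dark pixel
    [[id[1], id[2]] for id in ids]
end

toy_points = dark_pixel_coordinates(toy)
```

Note that `findall` walks the matrix column by column, so the points come out ordered by column first.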

We convert each image to a set of points


In [None]:
pts = img_to_points.(images);

and normalize the coordinates, since each image has a different size:


In [None]:
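# Note: despite the `!`, this function does not mutate `pts`; it returns a rescaled copy.
function normalize!(pts)
    # Range of the second coordinate (the column index) of the points.
    a, b = extrema(pts .|> last)

    pts ./ (b - a)
end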

wings = normalize!.(pts);
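
As a small sanity check of what this rescaling does, the sketch below applies the same arithmetic to a hypothetical three-point cloud (the values are made up for illustration). Both coordinates are divided by the range of the second one, so the aspect ratio is preserved and the second coordinate ends up spanning one unit.

```julia
# Hypothetical toy point cloud: [row, column] pairs.
toy_pts = [[10.0, 20.0], [30.0, 60.0], [20.0, 40.0]]

# Range of the second coordinate (the columns), as in the function above.
a, b = extrema(last.(toy_pts))

# Divide every coordinate by that range.
scaled = toy_pts ./ (b - a)

# Extent of the second coordinate after scaling.
extent = maximum(last.(scaled)) - minimum(last.(scaled))
```

The points are only rescaled, not translated, which is harmless here because the Vietoris-Rips filtration depends only on pairwise distances.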

A scatter plot confirms that the result looks right:


In [None]:
scatter(wings[1])

In order to apply the Vietoris-Rips filtration, we need to reduce the number of points in each wing. Farthest point sampling comes to our rescue again!


In [None]:
wings_short = @showprogress map(wings) do w
    ids = farthest_points_sample(w, 400)
    w[ids]
end;
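
As a reminder of the idea behind farthest point sampling, here is a minimal greedy sketch in plain Julia. It is not the `MetricSpaces.jl` implementation: the function name `greedy_farthest_points` is hypothetical, and starting from the first point is an assumption (other implementations may choose the starting point differently).

```julia
# Greedy farthest point sampling on a vector of points (illustrative sketch).
function greedy_farthest_points(points, n)
    dist(p, q) = sqrt(sum(abs2, p .- q))
    chosen = [1]                                   # assumption: start at the first point
    # Distance from every point to its closest chosen point so far.
    mindist = [dist(p, points[1]) for p in points]
    while length(chosen) < min(n, length(points))
        next = argmax(mindist)                     # farthest point from the current sample
        push!(chosen, next)
        mindist = min.(mindist, [dist(p, points[next]) for p in points])
    end
    chosen
end

toy_cloud = [[0.0, 0.0], [0.1, 0.0], [1.0, 0.0], [0.0, 1.0], [0.9, 0.1]]
sample_ids = greedy_farthest_points(toy_cloud, 3)
```

Each new index is the point farthest from everything already chosen, so the sample spreads over the whole shape instead of piling up in dense regions.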

Now we calculate each barcode using the Vietoris-Rips filtration:


In [None]:
pds = @showprogress map(wings_short) do w
    ripserer(w, cutoff = 0.01)
end

We can now look at the subsampled metric space


In [None]:
scatter(wings_short[1])

and the corresponding 1-dimensional persistence diagram


In [None]:
Plots.plot(pds[1][2])

Now we compute the pairwise bottleneck distance between the 1-dimensional persistence diagrams of the wings:


In [None]:
function barcode_to_distance(pds)
    n = length(pds)
    DB = zeros(n, n)

    @showprogress for i ∈ 1:n
        for j ∈ i:n
            if i == j
                DB[i, j] = 0 
                continue 
            end

            DB[i, j] = Bottleneck()(pds[i][2], pds[j][2])
            DB[j, i] = DB[i, j]
        end
    end

    DB
end
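
For reference, the bottleneck distance computed by `Bottleneck()` in the function above can be written as follows. This is the standard textbook definition, stated here with our own notation ($D_1$, $D_2$ for the two diagrams and $\gamma$ for a matching), not a description of the library's internals:

$$
d_B(D_1, D_2) = \inf_{\gamma \colon D_1 \to D_2} \; \sup_{p \in D_1} \, \lVert p - \gamma(p) \rVert_\infty
$$

Here $\gamma$ ranges over all bijections between the two diagrams, where each diagram is extended with the points of the diagonal (counted with infinite multiplicity) so that unmatched points can be paired with the diagonal.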

In [None]:
DB = barcode_to_distance(pds)

We then use multidimensional scaling (MDS) to project the wings onto the plane and see whether the classes are well separated:


In [None]:
function mds_plot(D)
    M = fit(MDS, D; distances = true, maxoutdim = 2)
    Y = predict(M)

    ds.Row = 1:nrow(ds)

    dfs = @chain ds begin
        groupby(:Classe)
        collect
    end

    fig = Figure();
    ax = Makie.Axis(fig[1,1])

    colors = cgrad(:tableau_10, 8, categorical = true)

    for (i, df) ∈ enumerate(dfs)    
        scatter!(
            ax, Y[:, df.Row]
            , label = df.Classe[1], markersize = 15
            , color = colors[i]
            )
    end

    axislegend();
    fig
end

In [None]:
mds_plot(DB)
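
To recall what MDS does with a distance matrix, the following is a minimal sketch of classical (Torgerson) MDS using only the `LinearAlgebra` standard library. It illustrates the idea and is not the `MultivariateStats.jl` implementation; the function name `classical_mds` and the toy matrix are hypothetical.

```julia
using LinearAlgebra

# Classical MDS: embed points so that Euclidean distances approximate D.
function classical_mds(D, dims = 2)
    n = size(D, 1)
    J = I - fill(1 / n, n, n)          # centering matrix
    B = -0.5 * J * (D .^ 2) * J        # double-centered squared distances
    E = eigen(Symmetric(B))            # eigenvalues come in ascending order
    idx = sortperm(E.values, rev = true)[1:dims]
    # Coordinates: eigenvectors scaled by the square root of their eigenvalues
    # (negative eigenvalues are clamped to zero).
    (E.vectors[:, idx] .* sqrt.(max.(E.values[idx], 0))')'
end

# Distances between three collinear points at positions 0, 3 and 5.
D_toy = [0.0 3.0 5.0;
         3.0 0.0 2.0;
         5.0 2.0 0.0]

Y_toy = classical_mds(D_toy, 1)        # 1×3 matrix, one column per point
```

As in `mds_plot`, each column of the result is one embedded point. For this toy example a single dimension already reproduces the distances, up to translation and reflection.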

## Slicing it sideways

As we did with the handwritten digits dataset, we can also slice the wings sideways, filtering each image along one of its axes.


In [None]:
set_value(x, value) = x < 0.99 ? value : x

function side_filtration(img, axis = 1, invert = false)

    m = imfilter(img, Kernel.gaussian(1))
    # Send every non-background pixel (value below 0.99) to 0.
    m = set_value.(m, 0)

    pts = img_to_points(m)

    # Range of rows (axis == 1) or columns (otherwise) occupied by the wing.
    a, b = if axis == 1
        extrema(pts .|> first)
    else
        extrema(pts .|> last)
    end

    for i ∈ a:b

        v = (b - i) / (b - a)

        if invert == true
            v = 1.0 - v
        end

        if axis == 1
            m[i, :] = set_value.(m[i, :], v)
        else 
            m[:, i] = set_value.(m[:, i], v)
        end

    end

    m .|> float
end;

We can visualize the filtrations as follows:


In [None]:
img = images[5]
img2 = side_filtration(img, 1)
heatmap(img2)

In [None]:
img2 = side_filtration(img, 2)
heatmap(img2)

In [None]:
img2 = side_filtration(img, 1, true)
heatmap(img2)

In [None]:
img2 = side_filtration(img, 2, true)
heatmap(img2)
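
The effect of these filtrations is easier to see on a tiny example. The sketch below is a simplified, hypothetical variant of the row-wise case (no blur, exact binary values, made-up function name `toy_row_filtration`): every wing pixel receives a value that depends only on its row, while the background keeps the value 1.

```julia
# Toy binary image: 0.0 marks the wing, 1.0 the background.
toy_wing = [1.0 1.0 1.0 1.0;
            1.0 0.0 0.0 1.0;
            1.0 0.0 1.0 1.0;
            1.0 0.0 0.0 1.0;
            1.0 1.0 1.0 1.0]

# Simplified row-wise filtration: wing pixels in row i get (b - i) / (b - a),
# where a and b are the first and last rows that contain wing pixels.
function toy_row_filtration(img)
    out = copy(img)
    rows = [id[1] for id in findall(==(0.0), img)]
    a, b = extrema(rows)
    for i in a:b, j in axes(img, 2)
        if img[i, j] == 0.0
            out[i, j] = (b - i) / (b - a)
        end
    end
    out
end

swept = toy_row_filtration(toy_wing)
```

In a sublevel filtration, pixels enter in order of increasing value, so here the wing is swept from row `b` towards row `a`. Swapping the roles of rows and columns, or replacing the value `v` by `1 - v`, gives the other three directions.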

We then compute the barcodes of the four directional filtrations, this time using cubical complexes:


In [None]:
pds_x = @showprogress map(images) do img
    img2 = side_filtration(img)
    ripserer(Cubical(img2), cutoff = 0.1)
end

pds_y = @showprogress map(images) do img
    img2 = side_filtration(img, 2)
    ripserer(Cubical(img2), cutoff = 0.1)
end

pds_x2 = @showprogress map(images) do img
    img2 = side_filtration(img, 1, true)
    ripserer(Cubical(img2), cutoff = 0.1)
end

pds_y2 = @showprogress map(images) do img
    img2 = side_filtration(img, 2, true)
    ripserer(Cubical(img2), cutoff = 0.1)
end

The respective distance matrices are obtained with:


In [None]:
DB_x = barcode_to_distance(pds_x)
DB_y = barcode_to_distance(pds_y)
DB_x2 = barcode_to_distance(pds_x2)
DB_y2 = barcode_to_distance(pds_y2)

As the plots show, none of the tools we used before separates the classes well:


In [None]:
mds_plot(DB)

In [None]:
mds_plot(DB_x)

In [None]:
mds_plot(DB_y)

In [None]:
mds_plot(DB_x2)

In [None]:
mds_plot(DB_y2)

Even if we sum all these distances, each normalized by its maximum, we still cannot cluster any class correctly:


In [None]:
DB_final = zero(DB)

for d in [DB, DB_x, DB_y, DB_x2, DB_y2]
    DB_final = DB_final + (d ./ maximum(d))
end

In [None]:
mds_plot(DB_final)
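
A note on the normalization used in the sum above: the different filtrations produce distances on different scales, so each matrix is divided by its maximum before being added. The sketch below shows the same arithmetic on two hypothetical 2×2 distance matrices (the values are made up).

```julia
# Two hypothetical distance matrices on very different scales.
D_small = [0.0 0.1; 0.1 0.0]
D_large = [0.0 50.0; 50.0 0.0]

# Dividing each matrix by its maximum rescales both to the interval [0, 1]
# before adding, so that neither matrix dominates the sum.
combined = sum(d ./ maximum(d) for d in [D_small, D_large])
```

Without the division, `D_large` would account for almost all of the combined distance and the contribution of `D_small` would be invisible.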