In [1]:
using Flux
using Metal
using CSV
using DataFrames
using OneHotArrays
using Statistics


In [2]:
Metal.functional()
device = Flux.get_device(; verbose=true)
device.deviceID


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mUsing backend: Metal.


<AGXG14GDevice: 0x1293cfa00>
    name = Apple M2

In [3]:
# load breast cancer dataset
# csv_file_path = metal_dir * "data/breast-cancer-wisconsin.data"
csv_file_path = "data/breast-cancer-wisconsin.data"

"data/breast-cancer-wisconsin.data"

In [4]:
df_orig = CSV.File(csv_file_path) |> DataFrame;
column_names = [:SampleCodeNumber, :ClumpThickness, :UniformityOfCellSize,
:UniformityOfCellShape, :MarginalAdhesion, :SingleEpithelialCellSize,
:BareNuclei, :BlandChromatin, :NormalNucleoli, :Mitoses, :Class];
rename!(df_orig, column_names);

In [5]:
# # Remove unwanted columns
df = df_orig[:, Not(Cols(:BareNuclei, :SampleCodeNumber))];

In [6]:
# Define a dictionary to map values
class_mapping = Dict(2 => 0, 4 => 1)
# Use the map function to apply the mapping
df[!, :Class] = map(
    x -> class_mapping[x], 
    df[:, :Class]
);

In [7]:
model = Chain(
    Dense(8 => 2),  # model with 8 features and 2 classes
    σ               # sigmoid activation
) |> device         # output model to GPU

Chain(
  Dense(8 => 2),                        [90m# 18 parameters[39m
  NNlib.σ,
) 

In [8]:
classes = [0, 1]
accuracy(x, y) = mean(onecold(model(x), classes) .== y);
# accuracy(x, y_orig)

In [9]:
# function loss(ŷ, y)
#     Flux.logitcrossentropy(ŷ, y)
# end;

function loss(ŷ, y)
    Flux.binarycrossentropy(ŷ, y)
end;

In [10]:
X_train = Matrix{Float64}(df[:, Not(:Class)])
X_train = transpose(X_train)
y_train = vec(df[:, :Class]);

In [11]:
optimizer = Flux.setup(Adam(), model)
train_loader = Flux.DataLoader((X_train, y_train), batchsize=1, shuffle=false)

698-element DataLoader(::Tuple{LinearAlgebra.Transpose{Float64, Matrix{Float64}}, Vector{Int64}})
  with first element:
  (8×1 Matrix{Float64}, 1-element Vector{Int64},)

In [12]:
epochs = 1;

In [13]:
for epoch in 1:epochs
    total_accuracy = 0.
    total_loss = 0.
    for (x_cpu, y_cpu) in train_loader
        # pass the input and label to the GPU
        input = reshape(x_cpu, :) |> device
        label = y_cpu[1] |> device
        # calculate the gradient
        ∂f∂x = gradient(m -> loss(m(input), label), model)
        Flux.update!(optimizer, model, ∂f∂x[1])
        # calculate accuracy and loss
        ŷ = model(input)
        total_accuracy += accuracy(input, label)
        total_loss += loss(ŷ, label)
    end
    s = length(train_loader)
    avg_accuracy = total_accuracy / s
    avg_loss = total_loss / s
    println("Epoch $epoch: Accuracy=$avg_accuracy, Loss=$avg_loss")
end

Epoch 1: Accuracy=0.5859598853868195, Loss=0.8969378419307843


<img src=images/fluxnn.png width='50%' height='50%' > </img>

# &#x1F4DA; References

## Running the `fluxnn.jl` on the command line

```bash
julia --project=. src/fluxnn.jl
```
> Returns
```powershell
  Activating new project at `/var/folders/c3/1jkjmwxx5vncr00sw0gklknr0000gp/T/jl_HUwIKr`
   Resolving package versions...
    Updating `/private/var/folders/c3/1jkjmwxx5vncr00sw0gklknr0000gp/T/jl_HUwIKr/Project.toml`
  [a93c6f00] + DataFrames v1.6.1
  [587475ba] + Flux v0.14.12
  [dde4c033] + Metal v1.0.0
  [0b1bfda6] + OneHotArrays v0.2.5
    Updating `/private/var/folders/c3/1jkjmwxx5vncr00sw0gklknr0000gp/T/jl_HUwIKr/Manifest.toml`
  [621f4979] + AbstractFFTs v1.5.0
  [79e6a3ab] + Adapt v4.0.1
  [dce04be8] + ArgCheck v2.3.0
  [a9b6321e] + Atomix v0.1.0
⌅ [198e06fe] + BangBang v0.3.40
  [9718e550] + Baselet v0.1.1
  [fa961155] + CEnum v0.5.0
  [082447d4] + ChainRules v1.63.0
  [d360d2e6] + ChainRulesCore v1.22.1
  [bbf7d656] + CommonSubexpressions v0.3.0
  [34da2185] + Compat v4.14.0
  [a33af91c] + CompositionsBase v0.1.2
  [187b0558] + ConstructionBase v1.5.4
  [6add18c4] + ContextVariablesX v0.1.3
  [a8cc5b0e] + Crayons v4.1.1
  [9a962f9c] + DataAPI v1.16.0
  [a93c6f00] + DataFrames v1.6.1
  [864edb3b] + DataStructures v0.18.17
  [e2d170a0] + DataValueInterfaces v1.0.0
  [244e2a9f] + DefineSingletons v0.1.2
  [8bb1440f] + DelimitedFiles v1.9.1
  [163ba53b] + DiffResults v1.1.0
  [b552c78f] + DiffRules v1.15.1
  [ffbed154] + DocStringExtensions v0.9.3
  [e2ba6199] + ExprTools v0.1.10
  [cc61a311] + FLoops v0.2.1
  [b9860ae5] + FLoopsBase v0.1.1
  [1a297f60] + FillArrays v1.9.3
  [587475ba] + Flux v0.14.12
  [f6369f11] + ForwardDiff v0.10.36
  [d9f16b24] + Functors v0.4.7
  [0c68f7d7] + GPUArrays v10.0.2
  [46192b85] + GPUArraysCore v0.1.6
⌅ [61eb1bfa] + GPUCompiler v0.25.0
  [7869d1d1] + IRTools v0.4.12
  [22cec73e] + InitialValues v0.3.1
  [842dd82b] + InlineStrings v1.4.0
  [41ab1584] + InvertedIndices v1.3.0
  [92d709cd] + IrrationalConstants v0.2.2
  [82899510] + IteratorInterfaceExtensions v1.0.0
  [692b3bcd] + JLLWrappers v1.5.0
  [b14d175d] + JuliaVariables v0.2.4
  [63c18a36] + KernelAbstractions v0.9.17
  [929cbde3] + LLVM v6.6.0
  [b964fa9f] + LaTeXStrings v1.3.1
  [2ab3a3ac] + LogExpFunctions v0.3.27
  [d8e11817] + MLStyle v0.4.17
  [f1d291b0] + MLUtils v0.4.4
  [1914dd2f] + MacroTools v0.5.13
  [dde4c033] + Metal v1.0.0
⌅ [128add7d] + MicroCollections v0.1.4
  [e1d29d7a] + Missings v1.1.0
  [872c559c] + NNlib v0.9.12
  [77ba4419] + NaNMath v1.0.2
  [71a1bf82] + NameResolution v0.1.5
  [d8793406] + ObjectFile v0.4.1
  [e86c9b32] + ObjectiveC v1.1.0
  [0b1bfda6] + OneHotArrays v0.2.5
  [3bd65402] + Optimisers v0.3.2
  [bac558e1] + OrderedCollections v1.6.3
  [69de0a69] + Parsers v2.8.1
  [2dfb63ee] + PooledArrays v1.4.3
  [aea7be01] + PrecompileTools v1.2.0
  [21216c6a] + Preferences v1.4.1
  [8162dcfd] + PrettyPrint v0.2.0
  [08abe8d2] + PrettyTables v2.3.1
  [33c8b6b6] + ProgressLogging v0.1.4
  [c1ae055f] + RealDot v0.1.0
  [189a3867] + Reexport v1.2.2
  [ae029012] + Requires v1.3.0
  [6c6a2e73] + Scratch v1.2.1
  [91c51154] + SentinelArrays v1.4.1
  [efcf1570] + Setfield v1.1.1
  [605ecd9f] + ShowCases v0.1.0
  [699a6c99] + SimpleTraits v0.9.4
  [a2af1166] + SortingAlgorithms v1.2.1
  [dc90abb0] + SparseInverseSubset v0.1.2
  [276daf66] + SpecialFunctions v2.3.1
  [171d559e] + SplittablesBase v0.1.15
  [90137ffa] + StaticArrays v1.9.3
  [1e83bf80] + StaticArraysCore v1.4.2
  [82ae8749] + StatsAPI v1.7.0
  [2913bbd2] + StatsBase v0.34.2
  [892a3eda] + StringManipulation v0.3.4
  [09ab397b] + StructArrays v0.6.18
  [53d494c1] + StructIO v0.3.0
  [3783bdb8] + TableTraits v1.0.1
  [bd369af6] + Tables v1.11.1
  [a759f4b9] + TimerOutputs v0.5.23
⌃ [28d57a85] + Transducers v0.4.80
  [013be700] + UnsafeAtomics v0.2.1
  [d80eeb9a] + UnsafeAtomicsLLVM v0.1.3
  [e88e6eb3] + Zygote v0.6.69
  [700de1a5] + ZygoteRules v0.2.5
  [6e34b625] + Bzip2_jll v1.0.8+1
  [2e619515] + Expat_jll v2.5.0+0
  [dad2f222] + LLVMExtra_jll v0.0.29+0
  [7106de7a] + LibMPDec_jll v2.5.1+0
⌅ [e9f186c6] + Libffi_jll v3.2.2+1
  [0418c028] + Metal_LLVM_Tools_jll v0.5.1+0
  [458c3c95] + OpenSSL_jll v3.0.13+0
  [efe28fd5] + OpenSpecFun_jll v0.5.5+0
  [93d3a430] + Python_jll v3.10.13+0
  [76ed43ae] + SQLite_jll v3.45.0+0
  [ffd25f8a] + XZ_jll v5.6.0+0
  [0dad84c5] + ArgTools v1.1.1
  [56f22d72] + Artifacts
  [2a0f44e3] + Base64
  [ade2ca70] + Dates
  [8ba89e20] + Distributed
  [f43a241f] + Downloads v1.6.0
  [7b1f6079] + FileWatching
  [9fa8497b] + Future
  [b77e0a4c] + InteractiveUtils
  [4af54fe1] + LazyArtifacts
  [b27032c2] + LibCURL v0.6.4
  [76f85450] + LibGit2
  [8f399da3] + Libdl
  [37e2e46d] + LinearAlgebra
  [56ddb016] + Logging
  [d6f4376e] + Markdown
  [a63ad114] + Mmap
  [ca575930] + NetworkOptions v1.2.0
  [44cfe95a] + Pkg v1.10.0
  [de0858da] + Printf
  [3fa0cd96] + REPL
  [9a3f8284] + Random
  [ea8e919c] + SHA v0.7.0
  [9e88b42a] + Serialization
  [6462fe0b] + Sockets
  [2f01184e] + SparseArrays v1.10.0
  [10745b16] + Statistics v1.10.0
  [4607b0f0] + SuiteSparse
  [fa267f1f] + TOML v1.0.3
  [a4e569a6] + Tar v1.10.0
  [8dfed614] + Test
  [cf7118a7] + UUIDs
  [4ec0a83e] + Unicode
  [e66e0078] + CompilerSupportLibraries_jll v1.1.0+0
  [deac9b47] + LibCURL_jll v8.4.0+0
  [e37daf67] + LibGit2_jll v1.6.4+0
  [29816b5a] + LibSSH2_jll v1.11.0+1
  [c8ffd9c3] + MbedTLS_jll v2.28.2+1
  [14a3606d] + MozillaCACerts_jll v2023.1.10
  [4536629a] + OpenBLAS_jll v0.3.23+4
  [05823500] + OpenLibm_jll v0.8.1+2
  [bea87d4a] + SuiteSparse_jll v7.2.1+1
  [83775a58] + Zlib_jll v1.2.13+1
  [8e850b90] + libblastrampoline_jll v5.8.0+1
  [8e850ede] + nghttp2_jll v1.52.0+1
  [3f19e933] + p7zip_jll v17.4.0+2
        Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated -m`
   Resolving package versions...
    Updating `/private/var/folders/c3/1jkjmwxx5vncr00sw0gklknr0000gp/T/jl_HUwIKr/Project.toml`
  [4382bb9f] + metaldemo v0.1.0 `~/Developer/JuliaExperiments`
    Updating `/private/var/folders/c3/1jkjmwxx5vncr00sw0gklknr0000gp/T/jl_HUwIKr/Manifest.toml`
  [a963bdd2] + AtomsBase v0.3.5
⌅ [ab4f0b2a] + BFloat16s v0.4.2
  [d1d4a3ce] + BitFlags v0.1.8
  [e1450e63] + BufferedStreams v1.2.1
  [336ed68f] + CSV v0.10.12
  [46823bd8] + Chemfiles v0.10.41
  [944b1d66] + CodecZlib v0.7.4
  [35d6a980] + ColorSchemes v3.24.0
  [3da002f7] + ColorTypes v0.11.4
  [c3611d14] + ColorVectorSpace v0.10.0
  [5ae59095] + Colors v0.12.10
  [f0e56b4a] + ConcurrentUtilities v2.3.1
  [124859b0] + DataDeps v0.7.13
  [460bff9d] + ExceptionUnwrapping v0.1.10
  [5789e2e9] + FileIO v1.16.2
  [48062228] + FilePathsBase v0.9.21
  [53c48c17] + FixedPointNumbers v0.8.4
  [92fee26a] + GZip v0.6.2
  [c27321d9] + Glob v1.3.1
  [f67ccb44] + HDF5 v0.17.1
  [cd3eb016] + HTTP v1.10.2
  [c817782e] + ImageBase v0.1.7
  [a09fc81d] + ImageCore v0.10.2
  [4e3cecfd] + ImageShow v0.3.8
  [7d512f48] + InternedStrings v0.7.0
  [033835bb] + JLD2 v0.4.46
  [0f8b85d8] + JSON3 v1.14.0
  [8cdb02fc] + LazyModules v0.3.1
  [e6f89c97] + LoggingExtras v1.0.3
  [23992714] + MAT v0.10.6
  [eb30cadb] + MLDatasets v0.7.14
  [3da0fdf6] + MPIPreferences v0.1.10
  [dbb5928d] + MappedArrays v0.4.2
  [739be429] + MbedTLS v1.1.9
  [e94cdb99] + MosaicViews v0.3.4
  [15e1cf62] + NPZ v0.4.3
  [6fe1bfb0] + OffsetArrays v1.13.0
  [4d8831e6] + OpenSSL v1.4.1
  [5432bcbf] + PaddedViews v0.5.12
  [7b2266bf] + PeriodicTable v1.2.1
  [fbb45041] + Pickle v0.3.3
  [777ac1f9] + SimpleBufferStream v1.1.0
  [cae243ae] + StackViews v0.1.1
⌅ [5e0ebb24] + Strided v1.2.3
  [69024149] + StringEncodings v0.3.7
  [856f2bd8] + StructTypes v1.10.0
  [62fd8b95] + TensorCore v0.1.1
  [3bb67fe8] + TranscodingStreams v0.10.3
  [9d95972d] + TupleTools v1.5.0
  [5c2747f8] + URIs v1.5.1
  [1986cc42] + Unitful v1.19.0
  [a7773ee8] + UnitfulAtomic v1.0.0
  [ea10d353] + WeakRefStrings v1.4.2
  [76eceee3] + WorkerUtilities v1.6.1
  [a5390f91] + ZipFile v0.10.1
  [4382bb9f] + metaldemo v0.1.0 `~/Developer/JuliaExperiments`
  [78a364fa] + Chemfiles_jll v0.10.4+0
  [0234f1f7] + HDF5_jll v1.14.3+1
  [e33a78d0] + Hwloc_jll v2.10.0+0
  [94ce4f54] + Libiconv_jll v1.17.0+0
  [7cb0a576] + MPICH_jll v4.2.0+0
  [f1f71cc9] + MPItrampoline_jll v5.3.2+0
  [9237b28f] + MicrosoftMPI_jll v10.1.4+2
⌅ [fe0851c0] + OpenMPI_jll v4.1.6+0
  [477f73a3] + libaec_jll v1.0.6+1
        Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated -m`
[ Info: Using backend: Metal.
Epoch 1: Accuracy=0.5931232091690545, Loss=0.5558569233923004
```