In [1]:
import Pkg

#Pkg.add("CSV")
#Pkg.add("Random")

using Zygote
using Flux
using Flux: onehotbatch, onecold, crossentropy
using MLDatasets
#using CUDA
using Statistics
using Plots
using Images
using Colors
using ImageTransformations
using MLUtils
using CSV
using Random
using DataFrames


In [2]:
# Load the MNIST train/test splits; indexing the dataset with [:] yields (features, targets).
train_X, train_y = MNIST(split=:train)[:]
test_X, test_y = MNIST(split=:test)[:]

# MLDatasets already returns the images as Float32 scaled to [0, 1], so they must NOT be
# divided by 255 again (that would squash every pixel into [0, 1/255]). The conversion
# below only guarantees the element type the Float32 model parameters expect.
train_X = Float32.(train_X)
test_X = Float32.(test_X)

# Insert a singleton channel axis so the arrays have the (width, height, channels, batch)
# layout consumed by the convolutional layers; `:` lets reshape infer the batch size.
train_X = reshape(train_X, 28, 28, 1, :)
test_X = reshape(test_X, 28, 28, 1, :)

;

In [3]:
# Convolutional encoder -> dense bottleneck -> transposed-conv decoder -> scalar head.
# The network emits one real number per input image.
model = Chain(
    # Encoder: 28×28×1 -> 14×14×32 -> 7×7×64
    Conv((3, 3), 1 => 32, relu; pad=1),
    MaxPool((2, 2)),
    Conv((3, 3), 32 => 64, relu; pad=1),
    MaxPool((2, 2)),

    # Dense bottleneck on the flattened 7*7*64 feature vector
    Flux.flatten,
    Dense(7 * 7 * 64 => 256, relu),
    Dense(256 => 7 * 7 * 64, relu),

    # Restore the spatial layout so the transposed convolutions can upsample it
    feats -> reshape(feats, (7, 7, 64, size(feats, 2))),

    # Decoder: 7×7×64 -> 13×13×32 -> 25×25×1
    ConvTranspose((3, 3), 64 => 32, relu; stride=2, pad=1),
    ConvTranspose((3, 3), 32 => 1, σ; stride=2, pad=1),

    # Scalar head: flatten the 25×25 map, project to one value, drop the singleton row
    Flux.flatten,
    Dense(25 * 25 => 1),
    vec,
)



Chain(
  Conv((3, 3), 1 => 32, relu, pad=1),   # 320 parameters
  MaxPool((2, 2)),
  Conv((3, 3), 32 => 64, relu, pad=1),  # 18_496 parameters
  MaxPool((2, 2)),
  Flux.flatten,
  Dense(3136 => 256, relu),             # 803_072 parameters
  Dense(256 => 3136, relu),             # 805_952 parameters
  var"#1#4"(),
  ConvTranspose((3, 3), 64 => 32, relu, pad=1, stride=2),  # 18_464 parameters
  ConvTranspose((3, 3), 32 => 1, σ, pad=1, stride=2),  # 289 parameters
  var"#2#5"(),
  Dense(625 => 1),                      # 626 parameters
  var"#3#6"(),
)                   # Total: 14 arrays, 1_647_219 parameters, 6.285 MiB.

In [4]:

# Adam optimiser with Flux's default hyperparameters. `Adam` is the current spelling;
# the all-caps `ADAM` alias is deprecated.
opt = Adam();

In [5]:
"""
    train_model!(model, train_X, train_Y, opt, epochs, batch_size)

Train `model` in place by minimising the mean-squared error between `model(x)` and the
targets, and return the loss history.

# Arguments
- `model`: the Flux model to optimise; its parameters are mutated.
- `train_X`, `train_Y`: inputs and targets, batched together along their last dimension.
- `opt`: optimiser rule handed to `Flux.setup`.
- `epochs`: number of full passes over the data.
- `batch_size`: number of samples per mini-batch (batches are reshuffled every epoch).

# Returns
A `Vector{Float64}` with one entry per epoch: the sum of the mini-batch losses, each
measured in the same forward pass that produced the gradient (i.e. before the update).
"""
function train_model!(model, train_X, train_Y, opt, epochs, batch_size)
    data_loader = Flux.DataLoader((train_X, train_Y); batchsize=batch_size, shuffle=true)

    opt_state = Flux.setup(opt, model)
    total_loss = Float64[]  # concretely typed so the history is not a Vector{Any}

    for epoch in 1:epochs
        epoch_loss = 0.0  # Float64 accumulator keeps the variable type-stable
        for (x, y) in data_loader
            # One forward/backward pass yields both the loss and its gradient, so no
            # second forward pass is needed just for reporting.
            batch_loss, grads = Flux.withgradient(m -> Flux.Losses.mse(m(x), y), model)
            # The optimiser state was built for `model`, so update `model` itself.
            Flux.update!(opt_state, model, grads[1])
            epoch_loss += batch_loss
        end
        println("Epoch $epoch complete, Loss: $(epoch_loss)")
        push!(total_loss, epoch_loss)
    end
    return total_loss
end

train_model! (generic function with 1 method)

In [6]:
# Training hyperparameters: number of passes over the data and samples per mini-batch.
epochs, batch_size = 50, 64
;

In [8]:
# Run training with the settings defined above and keep the values returned by
# train_model! (one entry per epoch) for plotting; the trailing `;` suppresses cell output.
loss_list = train_model!(model, train_X, train_y, opt, epochs, batch_size);

In [None]:
# Plot the recorded training loss against the epoch index. The x-axis is derived from
# `loss_list` itself rather than the global `epochs`, so the plot still works if `epochs`
# was changed after the training cell was run.
p1 = plot(
    eachindex(loss_list), loss_list;
    xlabel="Epoch", ylabel="Loss", title="Loss vs. Epoch",
    legend=false, xticks=false, yticks=false,
)
display(p1)

In [None]:
#visualize_layers(model, train_x[:, :, :, 1:1])

"""
    accuracy(x, y; batchsize=1024)

Return the fraction of samples whose prediction from the global `model`, rounded to the
nearest integer, equals the integer label in `y`.

The model emits a single real number per image rather than per-class scores, so
`onecold` is not applicable; rounding the scalar output recovers the predicted digit.
Inputs are evaluated in chunks of `batchsize` to bound the memory used by the forward pass.
"""
function accuracy(x, y; batchsize=1024)
    correct = 0
    for (xb, yb) in Flux.DataLoader((x, y); batchsize=batchsize)
        # Rounded Float32 outputs compare equal to Int labels (e.g. 3.0f0 == 3).
        correct += count(round.(model(xb)) .== yb)
    end
    return correct / length(y)
end
println("Test Accuracy: ", accuracy(test_X, test_y))
