In [1]:
using Pkg

using Flux
using MLDatasets
using Images
using Zygote

In [2]:
# Load the MNIST images and integer digit labels for the train and test splits.
train_X, train_y = MNIST(split=:train)[:]
test_X, test_y = MNIST(split=:test)[:]

# MLDatasets already returns the pixel intensities as floating point values scaled
# to [0, 1] (default `Tx = Float32`), so dividing by 255 again would squash them
# into [0, 1/255]. Only make sure the element type is Float32.
train_X = Float32.(train_X)
test_X = Float32.(test_X)

# Flatten the two spatial dimensions of every image so the data has the layout
# (length, channels, batch_size) with a single channel.
train_X = reshape(train_X, :, 1, size(train_X, 3))
test_X = reshape(test_X, :, 1, size(test_X, 3))

# One-hot encode the digit labels 0-9 (one column per sample).
train_y = Flux.onehotbatch(train_y, 0:9)
test_y = Flux.onehotbatch(test_y, 0:9)

;

In [3]:
# Fully connected encoder followed by a transposed-convolution decoder that maps a
# flattened 784-pixel digit back to a 28×28 image.
#
# Spatial sizes through the decoder: 7×7 -> (upsample) 14×14 -> (upsample) 28×28.
# Every ConvTranspose uses a 3×3 kernel with stride 1 and pad=1, which keeps the
# spatial size unchanged, so only the Upsample layers change the resolution.
# (The previous 4×4 start with a final 4×4/pad=1 kernel produced 17×17 feature maps,
# which cannot be reshaped to 28×28.)
model = Chain(
    Dense(784, 128, relu),
    Dense(128, 7 * 7 * 16, relu),
    # Let `:` infer the batch dimension so both (784, N) and (784, 1, N) inputs work;
    # `size(x, 2)` would pick up the singleton channel axis of 3-D inputs.
    x -> reshape(x, 7, 7, 16, :),
    ConvTranspose((3, 3), 16 => 256, relu, pad=1),   # 7×7×256
    Upsample((2, 2)),                                # 14×14×256
    ConvTranspose((3, 3), 256 => 128, relu, pad=1),  # 14×14×128
    Upsample((2, 2)),                                # 28×28×128
    ConvTranspose((3, 3), 128 => 1, pad=1),          # 28×28×1
    # Drop the singleton channel axis: (28, 28, 1, N) -> (28, 28, N).
    x -> dropdims(x; dims=3)
)


In [None]:
# Mean squared error between the prediction of model `m` for input `x` and target `y`.
# The loss is referenced by its fully qualified name because `using Flux` does not
# export `mse`; this also matches the spelling used inside `train_model!`.
loss(m, x, y) = Flux.Losses.mse(m(x), y)

# `Adam` is the current name of the optimiser; `ADAM` is only a deprecated alias.
opt = Flux.Adam();

In [None]:
"""
    train_model!(model, train_X, train_Y, opt, epochs, batch_size)

Train `model` in place with a mean-squared-error loss and return the per-epoch losses.

# Arguments
- `model`: the Flux model to optimise; its parameters are mutated.
- `train_X`, `train_Y`: inputs and targets, batched along their last dimension.
- `opt`: optimiser rule handed to `Flux.setup`.
- `epochs`: number of passes over the data.
- `batch_size`: mini-batch size used by the shuffled `Flux.DataLoader`.

# Returns
A `Vector{Float64}` with one entry per epoch: the sum of the mini-batch losses of that
epoch, each measured in the same forward pass that produced the gradient (i.e. before
the corresponding parameter update).
"""
function train_model!(model, train_X, train_Y, opt, epochs, batch_size)
    data_loader = Flux.DataLoader((train_X, train_Y); batchsize=batch_size, shuffle=true)

    opt_state = Flux.setup(opt, model)
    total_loss = Float64[]

    for epoch in 1:epochs
        # Accumulate in Float64 so the variable never changes type inside the loop.
        epoch_loss = 0.0
        for (x, y) in data_loader
            # One forward/backward pass yields both the loss value and the gradient,
            # so the model does not have to be evaluated a second time for logging.
            batch_loss, grads = Flux.withgradient(m -> Flux.Losses.mse(m(x), y), model)
            # The optimiser state was built for `model` itself, so update the model
            # (not `Flux.trainable(model)`) with the matching gradient tree.
            Flux.update!(opt_state, model, grads[1])
            epoch_loss += batch_loss
        end
        println("Epoch $epoch complete")
        push!(total_loss, epoch_loss)
    end
    return total_loss
end;

In [None]:
# Training hyper-parameters: number of passes over the data and mini-batch size.
epochs, batch_size = 3, 64;

In [None]:
# The network maps flattened digits to 28×28 images, so the target of the MSE loss
# has to be the images themselves. The 10-class one-hot labels in `train_y` cannot
# match the shape of the network output and would make the loss throw.
train_targets = reshape(train_X, 28, 28, :)
loss_list = train_model!(model, train_X, train_targets, opt, epochs, batch_size);