In [None]:
using Pkg

#Pkg.add("ImageIO")
using Flux
using Flux: onehotbatch, onecold, crossentropy, @epochs

using Images, ImageIO
using Statistics

using MLDatasets


In [None]:
# Sanity-check the on-disk layout: print the dataset root, then the training split.
for dir in ("data", "data/train")
    println(readdir(dir))
end

In [None]:
"""
    load_cifake_data(base_path, split)

Load the images under `joinpath(base_path, split)` together with their class labels.

Images are taken from class subdirectories named `FAKE` (label 0) or `REAL` (label 1);
files in any other directory are ignored without being opened. Each image is converted
to RGB and stored as a `Float32` array in (H, W, C) layout.

# Returns
- `(images, labels)`: `images` is an `H×W×3×N` `Array{Float32,4}` and `labels` is the
  one-hot encoding of the labels over `0:1` (size `2×N`).
- `([], [])` when no image could be loaded (a warning is logged).

Files that fail to load are reported with `@warn` and skipped; an interrupt (Ctrl-C)
aborts the whole load.

# Throws
- `DimensionMismatch` if the loaded images do not all have the same size.
"""
function load_cifake_data(base_path, split)
    # Class directory name => numeric label.
    class_map = Dict("FAKE" => 0, "REAL" => 1)
    split_path = joinpath(base_path, split)

    images = Array{Float32,3}[]
    labels = Int[]

    for (dirpath, _, files) in walkdir(split_path)
        # Decide the label *before* loading anything: an image is only kept together
        # with its label, so `images` and `labels` can never drift out of sync.
        label = get(class_map, basename(dirpath), nothing)
        label === nothing && continue

        for file in files
            path = joinpath(dirpath, file)
            img_array = try
                # `channelview` of an RGB image is (C, H, W); Flux's `Conv` layers expect
                # channels in the third dimension, hence the permutation to (H, W, C).
                # `RGB.` also normalises grayscale/RGBA files to exactly 3 channels.
                permutedims(Float32.(channelview(RGB.(load(path)))), (2, 3, 1))
            catch e
                # Don't swallow Ctrl-C: abort the whole load instead of skipping a file.
                e isa InterruptException && rethrow()
                @warn "Error loading image: $path" exception = e
                continue
            end
            push!(images, img_array)
            push!(labels, label)
        end
    end

    if isempty(images)
        @warn "No images were loaded. Check the dataset path and structure."
        return [], []
    end

    return _stack_images(images), onehotbatch(labels, 0:1)
end

# Copy equally sized H×W×C images into one preallocated H×W×C×N array; this avoids
# splatting tens of thousands of arguments into `cat`.
function _stack_images(images)
    dims = size(first(images))
    batch = Array{Float32,4}(undef, dims..., length(images))
    for (i, img) in enumerate(images)
        size(img) == dims || throw(DimensionMismatch(
            "image $i has size $(size(img)), expected $dims"))
        batch[:, :, :, i] = img
    end
    return batch
end

In [4]:
# Root directory holding the `train/` and `test/` splits.
base_path = "data"

# Training split. The CNN below expects images as (H, W, C, N) and one-hot labels as
# (num_classes, N).
x_train, y_train = load_cifake_data(base_path, "train")
println("x_train shape: ", size(x_train))
println("y_train shape: ", size(y_train))

# Test split — same layout as the training split.
x_test, y_test = load_cifake_data(base_path, "test")
println("x_test shape: ", size(x_test))
println("y_test shape: ", size(y_test))


LoadError: InterruptException:

In [None]:
# Samples per chunk of the training set, and mini-batch size used within each chunk.
chunk_size, batch_size = 1000, 16

"""
    partition_dataset(x, y, chunk_size)

Split the observations of `x` and `y` into consecutive chunks of at most `chunk_size`
observations each.

Observations are taken along the **last** dimension of both arrays. For example, `x` can
be an `H×W×C×N` image batch, and `y` can be a `K×N` (e.g. one-hot) label matrix or a
length-`N` label vector. When `N` is not a multiple of `chunk_size`, the final chunk
holds the remainder.

# Returns
A `Vector` of `(x_chunk, y_chunk)` tuples; the chunks are copies, not views.

# Throws
- `ArgumentError` if `chunk_size < 1`.
- `DimensionMismatch` if `x` and `y` hold different numbers of observations.
"""
function partition_dataset(x, y, chunk_size)
    chunk_size >= 1 ||
        throw(ArgumentError("chunk_size must be positive, got $chunk_size"))
    n_samples = size(x, ndims(x))
    n_labels = size(y, ndims(y))
    n_labels == n_samples || throw(DimensionMismatch(
        "x has $n_samples observations but y has $n_labels"))

    ranges = (i:min(i + chunk_size - 1, n_samples) for i in 1:chunk_size:n_samples)
    # Slice `y` by observation (column for a matrix) — plain `y[r]` would linearly index
    # a label matrix and mix up labels of different samples.
    return [(_obs_slice(x, r), _obs_slice(y, r)) for r in ranges]
end

# Copy of the observations `idx` taken along the last dimension of `a`.
_obs_slice(a, idx) = a[ntuple(_ -> Colon(), ndims(a) - 1)..., idx]


# Split the training set into consecutive chunks of `chunk_size` samples; each element
# is an `(x_chunk, y_chunk)` tuple.
chunks = partition_dataset(x_train, y_train, chunk_size)


In [None]:
# A small CNN for 32×32 RGB inputs. Each 2×2 max-pool halves the spatial size
# (32 → 16 → 8 → 4), so the last feature map is 4×4 with 64 channels before flattening.
model = Chain(
    # Stage 1: 3 input channels → 16 feature maps
    Conv((3, 3), 3 => 16, relu; pad = 1),
    MaxPool((2, 2)),
    # Stage 2: 16 → 32 feature maps
    Conv((3, 3), 16 => 32, relu; pad = 1),
    MaxPool((2, 2)),
    # Stage 3: 32 → 64 feature maps
    Conv((3, 3), 32 => 64, relu; pad = 1),
    MaxPool((2, 2)),
    # Classifier head: flattened features → 128 hidden units → 2 class probabilities
    flatten,
    Dense(64 * 4 * 4, 128, relu),
    Dense(128, 2),
    softmax,
)

In [None]:
# Cross-entropy of the global `model` on a batch. `model` already ends in `softmax`, so
# `crossentropy` receives class probabilities.
loss(x, y) = crossentropy(model(x), y)
# Same loss for an explicitly passed model, the form `Flux.gradient(m -> ..., model)` needs.
loss(m, x, y) = crossentropy(m(x), y)

# Per-sample (image, label) pairs of the training split (for demonstration purposes).
# `X`/`Y` were never defined; use the loaded training data. Views avoid duplicating
# every image in memory.
dataset = [(view(x_train, :, :, :, i), y_train[:, i]) for i in axes(x_train, 4)]

# Use the ADAM optimizer.
opt = ADAM()

In [None]:
# Train `model` chunk by chunk, in shuffled mini-batches of `batch_size`.
# NOTE(review): `epochs` is not defined anywhere in this notebook; default to 10 unless
# it has been set elsewhere.
n_epochs = @isdefined(epochs) ? epochs : 10

# Explicit-style optimiser state, as required by `Flux.update!(state, model, grad)`.
opt_state = Flux.setup(opt, model)

for epoch in 1:n_epochs
    println("Epoch $epoch starting...")

    for (x_chunk, y_chunk) in chunks
        # Create a DataLoader for the current chunk
        train_data = Flux.DataLoader((x_chunk, y_chunk); batchsize=batch_size, shuffle=true)

        # Train on the current chunk
        for (x, y) in train_data
            # `m(x)` already runs the network; the old `loss(m(x), y)` applied it twice.
            grads = Flux.gradient(m -> crossentropy(m(x), y), model)
            # `gradient` returns one gradient per argument; the model's is the first.
            Flux.update!(opt_state, model, grads[1])
        end
    end

    println("Epoch $epoch complete")
end

In [None]:
# Training-set accuracy. `X`/`Y` were never defined, so the loaded training split is used.
# The network runs in mini-batches of `eval_batchsize` samples, because a single forward
# pass over the whole split would hold every intermediate feature map at once.
eval_batchsize = 256
n_train = size(x_train, 4)
predictions = reduce(hcat, [model(x_train[:, :, :, r])
                            for r in Iterators.partition(1:n_train, eval_batchsize)])
acc = mean(onecold(predictions) .== onecold(y_train))
println("Training Accuracy: $(acc*100)%")