In [46]:
using JuNN

using Printf, Random
# Seed Julia's default RNG with a fixed value so that code drawing from it
# (presumably weight init and DataLoader shuffling — confirm in JuNN) is
# reproducible across runs. The call returns the RNG, displayed as TaskLocalRNG().
Random.seed!(0)

TaskLocalRNG()

In [47]:
using JLD2
# All prepared arrays live in one JLD2 file. Read every variable with a single
# multi-name `load` call (which returns a tuple in the order requested) instead
# of six separate calls that each reopen the file and repeat the path literal.
data_path = "./data_rnn/imdb_dataset_prepared.jld2"
X_train, y_train, X_test, y_test, embeddings, vocab =
    load(data_path, "X_train", "y_train", "X_test", "y_test", "embeddings", "vocab")
nothing  # suppress display of the loaded arrays in the notebook

In [48]:
# Print "<name>: <size>" for each loaded array, in the same order they were loaded.
for (name, arr) in pairs((; X_train, y_train, X_test, y_test, embeddings, vocab))
    println(name, ": ", size(arr))
end

X_train: (130, 40000)
y_train: (1, 40000)
X_test: (130, 10000)
y_test: (1, 10000)
embeddings: (50, 12849)
vocab: (12849,)


In [49]:
# Dimensions derived from the loaded data (values taken from the shape printout above).
vocab_size = length(vocab)                     # 12849 vocabulary entries
embed_dim = first(size(embeddings))            # 50 (rows of the embedding matrix)
sequence_length = first(size(X_train))         # 130 (rows of X_train)
batch_size = 128

128

In [50]:
# Embedding -> RNN with a 16-unit hidden state -> single sigmoid output.
model = let hidden_units = 16
    Chain(
        Embedding(vocab_size, embed_dim, name="embedding"),
        RNN(embed_dim, hidden_units, name="rnn_layer"),
        Dense(hidden_units => 1, σ, name="output_layer"),
    )
end

# Copy the pretrained vectors into the embedding layer's weight matrix, in place.
model.layers[1].weights.output .= embeddings

dataset = DataLoader((X_train, y_train); batchsize = batch_size, shuffle = true)
testset = DataLoader((X_test, y_test); batchsize = batch_size, shuffle = false)

# accuracy(y_true, y_pred): fraction of entries where targets and predictions
# fall on the same side of the 0.5 threshold.
function accuracy(y_true, y_pred)
    predicted_positive = y_pred .> 0.5f0
    actually_positive = y_true .> 0.5f0
    return mean(predicted_positive .== actually_positive)
end

accuracy (generic function with 1 method)

In [51]:
# Bundle model, optimizer, loss and metric into a trainable network.
# seq_length/batch_size presumably size the input buffer (the display shows a
# 130×128 x_input) — confirm against JuNN.
net = NeuralNetwork(
    model,
    RMSProp(),
    binary_cross_entropy,
    accuracy,
    batch_size;
    seq_length = sequence_length,
)

Adding Embedding layer parameters: 1
Adding RNN layer parameters: 2
Adding RNN layer bias: 2
Adding Dense layer parameters: 3
Adding Dense layer bias: 3


NeuralNetwork(Chain{Tuple{Embedding, RNN, Dense{typeof(σ)}}}((Embedding(var embedding
 ┣━ ^ 50×12849 Matrix{Float32}
 ┗━ ∇ Nothing), RNN(RNNCell(var rnn_layer_W_ih
 ┣━ ^ 16×50 Matrix{Float32}
 ┗━ ∇ Nothing, var rnn_layer_W_hh
 ┣━ ^ 16×16 Matrix{Float32}
 ┗━ ∇ Nothing, var rnn_layer_bias
 ┣━ ^ 16×1 Matrix{Float32}
 ┗━ ∇ Nothing, JuAD.ReLU), 16), Dense{typeof(σ)}(var output_layer
 ┣━ ^ 1×16 Matrix{Float32}
 ┗━ ∇ Nothing, var output_layer_bias
 ┣━ ^ 1-element Vector{Float32}
 ┗━ ∇ Nothing, JuAD.σ))), RMSProp(0.001f0, 0.9f0, 1.1920929f-7, IdDict{Any, Array{Float32}}()), JuNN.binary_cross_entropy, Main.accuracy, var x_input
 ┣━ ^ 130×128 Matrix{Int32}
 ┗━ ∇ Nothing, var y_true
 ┣━ ^ 1×128 Matrix{Float32}
 ┗━ ∇ Nothing, op.?(typeof(σ)), JuAD.GraphNode[const -1.0, var y_true
 ┣━ ^ 1×128 Matrix{Float32}
 ┗━ ∇ Nothing, op.?(typeof(*)), var output_layer
 ┣━ ^ 1×16 Matrix{Float32}
 ┗━ ∇ Nothing, var rnn_layer_W_ih
 ┣━ ^ 16×50 Matrix{Float32}
 ┗━ ∇ Nothing, var embedding
 ┣━ ^ 50×12849 Matrix{Floa

In [52]:
epochs = 12
for epoch in 1:epochs
    # @time prints the training pass's wall time/allocations and returns
    # train!'s result unchanged, so it can be destructured directly
    # (the old `t = @time begin ... end` binding was never used).
    train_loss, train_acc = @time train!(net, dataset)

    test_loss, test_acc = evaluate(net, testset)
    @printf("Epoch %d/%d: Train Loss: %.4f, Train Acc: %.4f, Test Loss: %.4f, Test Acc: %.4f\n",
            epoch, epochs, train_loss, train_acc, test_loss, test_acc)
end

 30.914038 seconds (82.25 M allocations: 136.616 GiB, 34.22% gc time, 0.27% compilation time)
Epoch 1/12: Train Loss: 0.6980, Train Acc: 0.5090, Test Loss: 0.6950, Test Acc: 0.5006
 31.098906 seconds (81.90 M allocations: 136.596 GiB, 34.10% gc time)
Epoch 2/12: Train Loss: 0.6908, Train Acc: 0.5190, Test Loss: 0.6938, Test Acc: 0.5052
 31.321431 seconds (81.90 M allocations: 136.596 GiB, 33.90% gc time)
Epoch 3/12: Train Loss: 0.6860, Train Acc: 0.5296, Test Loss: 0.6921, Test Acc: 0.5170
 31.391750 seconds (81.90 M allocations: 136.596 GiB, 33.80% gc time)
Epoch 4/12: Train Loss: 0.6789, Train Acc: 0.5442, Test Loss: 0.6780, Test Acc: 0.6452
 31.660940 seconds (81.90 M allocations: 136.596 GiB, 33.89% gc time)
Epoch 5/12: Train Loss: 0.6276, Train Acc: 0.6720, Test Loss: 0.6274, Test Acc: 0.6776
 32.059743 seconds (81.90 M allocations: 136.596 GiB, 34.21% gc time)
Epoch 6/12: Train Loss: 0.6077, Train Acc: 0.6927, Test Loss: 0.6168, Test Acc: 0.6983
 32.612973 seconds (81.90 M alloca