In [9]:
// Add the maven dependencies
%maven ai.djl:api:0.28.0
%maven ai.djl:basicdataset:0.28.0
%maven ai.djl:model-zoo:0.28.0
%maven ai.djl.mxnet:mxnet-engine:0.28.0
%maven org.slf4j:slf4j-simple:1.7.36

In [23]:
// Import some packages
import java.nio.file.*;

import ai.djl.*;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.basicmodelzoo.basic.*;
import ai.djl.basicmodelzoo.cv.classification.*;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.types.*;
import ai.djl.nn.*;
import ai.djl.nn.core.*;
import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.evaluator.*;
import ai.djl.training.initializer.*;
import ai.djl.training.listener.*;
import ai.djl.training.loss.*;
import ai.djl.training.optimizer.*;
import ai.djl.training.util.*;

In [14]:
// Problem definition: classify MNIST handwritten digits.
// NOTE(review): `application` is not referenced by any of the visible cells.
Application application = Application.CV.IMAGE_CLASSIFICATION;

// Network input/output dimensions.
long outputSize = 10;      // one score per digit class (0-9)
long inputSize = 28L * 28; // a 28 x 28 image flattened into a single vector

In [15]:
// Build a multilayer perceptron (MLP) as a SequentialBlock; layers run in the order added:
//   flatten each input to `inputSize` values
//   -> Linear(hiddenUnits1) -> ReLU -> Linear(hiddenUnits2) -> ReLU
//   -> Linear(outputSize): one raw score per class. There is no softmax layer here;
//      the training configuration pairs these scores with softmaxCrossEntropyLoss.
//
// Background: DJL's core data type is the NDArray, a multidimensional, fixed-size,
// homogeneous array that behaves much like a NumPy array but with efficient computation.
// An NDList is a list of NDArrays that may differ in size and data type.
long hiddenUnits1 = 128; // width of the first hidden layer
long hiddenUnits2 = 64;  // width of the second hidden layer

SequentialBlock block = new SequentialBlock()
        .add(Blocks.batchFlattenBlock(inputSize))
        .add(Linear.builder().setUnits(hiddenUnits1).build())
        .add(Activation::relu)
        .add(Linear.builder().setUnits(hiddenUnits2).build())
        .add(Activation::relu)
        .add(Linear.builder().setUnits(outputSize).build());

// Display the block's structure.
block

SequentialBlock {
	batchFlatten
	Linear
	LambdaBlock
	Linear
	LambdaBlock
	Linear
}

In [16]:
// Load the MNIST dataset. The sampler controls which elements, and how many, form each
// batch during iteration: here `batchSize` elements per batch, drawn in shuffled order.
// Batch size is a tunable hyperparameter; powers of 2 that fit in memory are a common choice.
int batchSize = 32;
Mnist mnist = Mnist.builder()
        .setSampling(batchSize, true) // true = random (shuffled) sampling
        .build();

// Prepare the dataset, downloading files as needed; progress is shown with a ProgressBar.
mnist.prepare(new ProgressBar());

[IJava-executor-4] INFO ai.djl.mxnet.jna.LibUtils - Downloading libgfortran.so.3 ...
[IJava-executor-4] INFO ai.djl.mxnet.jna.LibUtils - Downloading libgomp.so.1 ...
[IJava-executor-4] INFO ai.djl.mxnet.jna.LibUtils - Downloading libquadmath.so.0 ...
[IJava-executor-4] INFO ai.djl.mxnet.jna.LibUtils - Downloading libopenblas.so.0 ...
[IJava-executor-4] INFO ai.djl.mxnet.jna.LibUtils - Downloading libmxnet.so ...


Downloading: 100% |████████████████████████████████████████|


In [17]:
// With the training data ready, wrap the network built earlier in a Model named "mlp"
// so that it can be trained and saved.
Model model = Model.newInstance("mlp");
model.setBlock(block); // attach the SequentialBlock defined above

In [19]:
// Training configuration
//  - Loss (required): measures how well the model fits the dataset; training minimizes it,
//    hence the name "loss" (lower is better). It is the only required argument of
//    DefaultTrainingConfig. softmaxCrossEntropyLoss is a standard loss for classification.
//  - Evaluators (optional): further measures of fit that are reported for people to read
//    but are not used to optimize the model. Losses are often hard to interpret, so an
//    Accuracy evaluator is added to show how well the model is doing.
//  - Training listeners: add functionality to the training process through a listener
//    interface, e.g. showing progress, stopping early if training becomes undefined, or
//    recording performance metrics. TrainingListener.Defaults offers ready-made sets.
DefaultTrainingConfig config =
        new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());

// Create a trainer for the model from this configuration.
Trainer trainer = model.newTrainer(config);

[IJava-executor-6] INFO ai.djl.training.listener.LoggingTrainingListener - Training on: cpu().
[IJava-executor-6] INFO ai.djl.training.listener.LoggingTrainingListener - Load MXNet Engine Version 1.9.0 in 0.039 ms.


In [20]:
// Initialize the model's parameters for inputs of shape (batch, inputSize).
//  - Axis 0 is the batch size. It does not affect parameter initialization, so 1 suffices.
//  - Axis 1 is the number of pixels per image. Reuse `inputSize`, the same value the
//    flatten block was built with, instead of repeating 28 * 28, so the two cannot drift apart.
trainer.initialize(new Shape(1, inputSize));

// Train for `epoch` epochs; each epoch trains the model on every item in the dataset once.
// No validation dataset is passed (null).
int epoch = 2;
EasyTrain.fit(trainer, epoch, mnist, null);

Training:    100% |████████████████████████████████████████| 


[IJava-executor-7] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 1 finished.


Training:    100% |████████████████████████████████████████| 


[IJava-executor-7] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 2 finished.


In [21]:
// Persist the trained model to disk (parameters and related files) under build/mlp.
Path modelDir = Paths.get("build/mlp");
Files.createDirectories(modelDir); // also succeeds when the directory already exists

// Store the epoch count as a model property; it appears in the model summary below.
model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "mlp");

// Display the model summary.
model

Model (
	Name: mlp
	Model location: /home/jupyter/build/mlp
	Data Type: float32
	Epoch: 2
)