In [2]:
// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.8.0
%maven ai.djl:basicdataset:0.8.0
%maven ai.djl.mxnet:mxnet-engine:0.8.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.8.0
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0

// See https://github.com/awslabs/djl/blob/master/mxnet/mxnet-engine/README.md
// for more MXNet library selection options
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport

In [3]:
import ai.djl.*;
import ai.djl.basicdataset.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.transform.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.types.*;
import ai.djl.nn.*;
import ai.djl.nn.core.*;
import ai.djl.repository.zoo.*;
import ai.djl.training.*;
import ai.djl.training.dataset.*;
import ai.djl.training.initializer.*;
import ai.djl.training.listener.*;
import ai.djl.training.loss.*;
import ai.djl.training.evaluator.*;
import ai.djl.training.optimizer.*;
import ai.djl.training.tracker.*;
import ai.djl.training.util.*;
import ai.djl.translate.*;
import java.nio.file.*;
import java.util.*;
import java.util.concurrent.*;

1. Use the Criteria API to load a pretrained model from the DJL ModelZoo

In [4]:
Criteria<Image, Classifications> criteria = Criteria.builder()
    // describe the pretrained model to load: ResNet-50 v1 from the model zoo
    .setTypes(Image.class, Classifications.class)
    .optProgress(new ProgressBar())
    .optArtifactId("resnet")
    .optFilter("layers", "50")
    .optFilter("flavor", "v1")
    .build();
Model model = ModelZoo.loadModel(criteria);

Loading:     100% |████████████████████████████████████████|


2. Remove the last fully connected layer of the pretrained model and add a new fully connected layer (Linear block) with 102 output classes

In [5]:
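// The pretrained ResNet is an MXNet symbolic block; drop its final fully connected layer
SymbolBlock block = (SymbolBlock) model.getBlock();
block.removeLastBlock();

SequentialBlock newBlock = new SequentialBlock();
newBlock.add(block);
// Flatten the backbone's pooled feature map into (batch, features) for the dense layer
newBlock.add(Blocks.batchFlattenBlock());
// New fully connected layer with one output per flower class (102 classes)
newBlock.add(Linear.builder().setUnits(102).build());
model.setBlock(newBlock);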
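For intuition about the new head: assuming the truncated ResNet-50 backbone ends in global average pooling with an output shape of (N, 2048, 1, 1) (the usual ResNet-50 v1 feature size, assumed here rather than read from the loaded model), the batch-flatten step yields (N, 2048), and a Linear block with 102 units holds 2048 × 102 weights plus 102 biases. The plain-Java sketch below does that arithmetic without any DJL calls; the class and method names are hypothetical.

```java
import java.util.Arrays;

class NewHeadShapes {
    // Mimics batch flatten: keep the batch axis, multiply all remaining axes together.
    static long[] batchFlatten(long[] shape) {
        long features = 1;
        for (int i = 1; i < shape.length; i++) {
            features *= shape[i];
        }
        return new long[] {shape[0], features};
    }

    // Parameter count of a fully connected layer: one weight per (input, unit) pair plus one bias per unit.
    static long linearParams(long inFeatures, long units) {
        return inFeatures * units + units;
    }

    public static void main(String[] args) {
        long[] backboneOut = {32, 2048, 1, 1}; // assumed backbone output for a batch of 32
        long[] flat = batchFlatten(backboneOut);
        System.out.println(Arrays.toString(flat));      // [32, 2048]
        System.out.println(linearParams(flat[1], 102)); // 208998
    }
}
```

Only the 102-unit layer is new; all of the backbone's parameters come from the pretrained checkpoint.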

3. Prepare the dataset: [102 Category Flower Dataset](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html)

Download link: https://d2l-java-resources.s3.amazonaws.com/flower_dataset.zip

In [15]:
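import ai.djl.util.ZipUtils;
import java.io.InputStream;
import java.net.URL;

// Download the archive and unzip it into the current directory (should create ./flower_dataset)
URL url = new URL("https://d2l-java-resources.s3.amazonaws.com/flower_dataset.zip");
try (InputStream is = url.openStream()) {
    ZipUtils.unzip(is, Paths.get("./"));
}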
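ImageFolder treats each subdirectory of a dataset root as one class, so after unzipping, `flower_dataset/train` and `flower_dataset/test` should each contain one folder per flower category. The sketch below lists those folders in sorted order and assigns each an integer label, which is roughly what a synset is. It uses only `java.nio`; the class name is hypothetical, and whether DJL orders classes exactly this way is an assumption, so treat `getSynset()` on the prepared dataset as the source of truth.

```java
import java.io.IOException;
import java.nio.file.*;
import java.util.*;
import java.util.stream.*;

class FolderLabels {
    // Returns the class-folder names under root, sorted, so list index == label index.
    static List<String> classNames(Path root) throws IOException {
        try (Stream<Path> entries = Files.list(root)) {
            return entries.filter(Files::isDirectory)
                    .map(p -> p.getFileName().toString())
                    .sorted()
                    .collect(Collectors.toList());
        }
    }

    public static void main(String[] args) throws IOException {
        List<String> names = classNames(Paths.get("flower_dataset/train"));
        System.out.println(names.size() + " classes");
        for (int i = 0; i < Math.min(3, names.size()); i++) {
            System.out.println(i + " -> " + names.get(i));
        }
    }
}
```

Plain files at the root (for example a stray README) are skipped, because only directories count as classes.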

In [None]:
import ai.djl.basicdataset.ImageFolder;
import ai.djl.repository.Repository;

int batchSize = 32;
// ImageNet channel statistics, the usual normalization for ImageNet-pretrained ResNets
float[] mean = {0.485f, 0.456f, 0.406f};
float[] std = {0.229f, 0.224f, 0.225f};
int resize_w = 224;
int resize_h = 224;

ImageFolder trainDataset =
    ImageFolder.builder()
    .setRepository(Repository.newInstance("flower_train", "flower_dataset/train"))
    .optPipeline(
        // preprocessing: center crop -> resize to 224x224 -> CHW float tensor -> normalize
        new Pipeline()
            .add(new CenterCrop())
            .add(new Resize(resize_w, resize_h))
            .add(new ToTensor())
            .add(new Normalize(mean, std)))
    .setSampling(batchSize, true) // shuffle training batches
    .build();
// prepare() must be called before the dataset is used
trainDataset.prepare(new ProgressBar());

ImageFolder testDataset =
    ImageFolder.builder()
    .setRepository(Repository.newInstance("flower_test", "flower_dataset/test"))
    .optPipeline(
        // same preprocessing as training so evaluation sees identically scaled inputs
        new Pipeline()
            .add(new CenterCrop())
            .add(new Resize(resize_w, resize_h))
            .add(new ToTensor())
            .add(new Normalize(mean, std)))
    .setSampling(batchSize, false) // no need to shuffle the evaluation set
    .build();
testDataset.prepare(new ProgressBar());

// class names in label order
trainDataset.getSynset()
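Per pixel, the preprocessing pipeline boils down to simple arithmetic: CenterCrop with no arguments (as we understand DJL's default) takes the largest centered square, Resize scales it to 224 × 224, ToTensor moves channels first and scales values from [0, 255] to [0, 1], and Normalize applies (x − mean) / std per channel. The plain-Java sketch below reproduces only the crop-box and normalization steps under those assumptions. It illustrates the math and is not DJL's implementation; the class name is hypothetical.

```java
class PreprocessMath {
    // Returns {x, y, side} of the largest square centered in a width x height image.
    static int[] centerSquare(int width, int height) {
        int side = Math.min(width, height);
        int x = (width - side) / 2;
        int y = (height - side) / 2;
        return new int[] {x, y, side};
    }

    // ToTensor + Normalize for one RGB pixel: scale to [0, 1], then standardize per channel.
    static float[] normalizePixel(int[] rgb, float[] mean, float[] std) {
        float[] out = new float[3];
        for (int c = 0; c < 3; c++) {
            out[c] = (rgb[c] / 255f - mean[c]) / std[c];
        }
        return out;
    }

    public static void main(String[] args) {
        int[] box = centerSquare(500, 375); // a 500x375 photo -> 375x375 square starting at x=62
        System.out.println(java.util.Arrays.toString(box)); // [62, 0, 375]
        float[] mean = {0.485f, 0.456f, 0.406f};
        float[] std = {0.229f, 0.224f, 0.225f};
        System.out.println(java.util.Arrays.toString(normalizePixel(new int[] {255, 0, 0}, mean, std)));
    }
}
```

After normalization a typical channel value is centered near 0, which is the input range the pretrained backbone expects.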

4. Configure the TrainingConfig: use softmaxCrossEntropy as the loss function and Accuracy as the evaluator, and train on one GPU

In [None]:
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
    // softmaxCrossEntropyLoss is the standard loss for multi-class classification
    .addEvaluator(new Accuracy()) // report accuracy, which is easier for humans to interpret than the loss
    .optDevices(Device.getDevices(1)) // use at most one GPU (falls back to CPU if none is available)
    .addTrainingListeners(TrainingListener.Defaults.basic());

Trainer trainer = model.newTrainer(config);
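To make the two training signals concrete: for one example with logits z and true class y, softmax cross-entropy is −log(exp(z_y) / Σ_j exp(z_j)), averaged over the batch, and accuracy is the fraction of examples whose arg-max logit equals the label. The plain-Java sketch below computes both for a toy batch. It follows the standard textbook definitions rather than DJL's source, and the class name is hypothetical.

```java
class LossAndAccuracy {
    // Mean softmax cross-entropy over a batch; subtracts the row max for numerical stability.
    static double softmaxCrossEntropy(double[][] logits, int[] labels) {
        double total = 0;
        for (int i = 0; i < logits.length; i++) {
            double max = Double.NEGATIVE_INFINITY;
            for (double z : logits[i]) max = Math.max(max, z);
            double sumExp = 0;
            for (double z : logits[i]) sumExp += Math.exp(z - max);
            double logProb = logits[i][labels[i]] - max - Math.log(sumExp);
            total += -logProb;
        }
        return total / logits.length;
    }

    // Fraction of rows whose arg-max logit equals the label.
    static double accuracy(double[][] logits, int[] labels) {
        int correct = 0;
        for (int i = 0; i < logits.length; i++) {
            int best = 0;
            for (int j = 1; j < logits[i].length; j++) {
                if (logits[i][j] > logits[i][best]) best = j;
            }
            if (best == labels[i]) correct++;
        }
        return (double) correct / logits.length;
    }

    public static void main(String[] args) {
        double[][] logits = {{2.0, 0.5, 0.1}, {0.2, 0.3, 1.5}};
        int[] labels = {0, 1}; // first example predicted correctly, second not
        System.out.println(softmaxCrossEntropy(logits, labels));
        System.out.println(accuracy(logits, labels)); // 0.5
    }
}
```

The loss is smooth and differentiable, so it drives the gradient updates, while accuracy only counts hits and is reported for monitoring.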

In [None]:
import ai.djl.metric.Metrics;

int epoch = 10;
// NCHW input shape: (batch, channels, height, width)
Shape inputShape = new Shape(1, 3, resize_h, resize_w);
trainer.initialize(inputShape);
trainer.setMetrics(new Metrics());
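The input shape follows the NCHW layout (batch, channels, height, width). As a back-of-the-envelope memory check, one full training batch of 32 float32 images at 224 × 224 holds 32 × 3 × 224 × 224 = 4,816,896 values, about 18.4 MiB, before any activations or gradients are counted. The sketch below (plain Java, hypothetical class name) does the arithmetic.

```java
class BatchTensorSize {
    // Number of elements in an NCHW tensor.
    static long elements(long n, long c, long h, long w) {
        return n * c * h * w;
    }

    public static void main(String[] args) {
        long count = elements(32, 3, 224, 224);
        long bytes = count * Float.BYTES; // float32 = 4 bytes per value
        System.out.println(count); // 4816896
        System.out.printf("%.3f MiB%n", bytes / (1024.0 * 1024.0));
    }
}
```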

5. Train with the fit method of EasyTrain

In [None]:
EasyTrain.fit(trainer, epoch, trainDataset, testDataset)


6. Save the model to local disk

In [None]:
Path modelDir = Paths.get("build/resnet");
Files.createDirectories(modelDir);

model.setProperty("Epoch", String.valueOf(epoch));
model.save(modelDir, "resnet");
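As far as we know, `model.save` names the parameter file after the model name and the `Epoch` property, zero-padded to four digits, so the call above should produce something like `build/resnet/resnet-0010.params`. Treat this naming rule as an assumption and confirm it against the directory listing. The hypothetical helper below reproduces the assumed rule and checks whether the file exists.

```java
import java.nio.file.*;
import java.util.Locale;

class SavedParamsName {
    // Assumed naming rule: <modelName>-<epoch as 4 digits>.params
    static String paramsFileName(String modelName, int epoch) {
        return String.format(Locale.ROOT, "%s-%04d.params", modelName, epoch);
    }

    public static void main(String[] args) {
        Path expected = Paths.get("build/resnet").resolve(paramsFileName("resnet", 10));
        System.out.println(expected + " exists: " + Files.exists(expected));
    }
}
```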

7. Load the model just saved and use it to predict the class of a flower image