-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.m
77 lines (58 loc) · 2.12 KB
/
train.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
% ---- Dataset -------------------------------------------------------------
% Root folder of the dataset: frames live in WPI/images, label maps in
% WPI/labels.
outputFolder = fullfile('WPI');
imagesFolder = fullfile(outputFolder, 'images');
labelsFolder = fullfile(outputFolder, 'labels');

% Frames are stored as .mat files, so a custom ReadFcn decodes each one.
imds = imageDatastore(imagesFolder, 'FileExtensions', '.mat', 'ReadFcn', @helperImageMatReader);

% Semantic classes (column string array). Label value k in the label files
% is mapped to classNames(k) via labelIDs = 1..numClasses.
classNames = ["background"; "car"; "truck"];
numClasses = numel(classNames);
labelIDs = 1:numClasses;
pxds = pixelLabelDatastore(labelsFolder, classNames, labelIDs);
% ---- Visual sanity check ------------------------------------------------
% Overlay the ground-truth label map of a single frame on the frame itself.
imageNumber = 225;

% Frame as decoded by helperImageMatReader: channels 1-3 hold point
% location and channel 4 intensity.
% NOTE(review): the network input is declared with 5 channels; confirm what
% channel 5 holds (possibly range).
I = readimage(imds, imageNumber);
% Matching categorical label map from the pixel-label datastore.
labelMap = readimage(pxds, imageNumber);

figure;
helperDisplayLidarOverlayImage(I, labelMap, classNames);
title('Ground Truth');
% ---- Train / validation / test split -----------------------------------
% The project helper splits frames and label maps into three subsets.
[imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = helperPartitionLidarData(imds, pxds);

% Pair frames with their label maps so that each read returns both.
validationData = combine(imdsVal, pxdsVal);
trainingData = combine(imdsTrain, pxdsTrain);

% Run augmentData on every read of the training pairs.
augmentedTrainingData = transform(trainingData, @(sample) augmentData(sample));
% ---- Class balancing ------------------------------------------------------
% Median-frequency balancing: weight_c = median(freq) / freq_c, where
% freq_c = PixelCount_c / ImagePixelCount_c.
% Statistics are taken from the training split only, so validation and test
% labels do not influence the loss weights.
% NOTE(review): assumes helperPartitionLidarData returns pixel-label
% datastore subsets (countEachLabel needs one) -- confirm.
tbl = countEachLabel(pxdsTrain);
tbl(:,{'Name','PixelCount','ImagePixelCount'})
imageFreq = tbl.PixelCount ./ tbl.ImagePixelCount;

% A class absent from the training labels gives 0/0 = NaN. That would turn
% the median, and therefore every weight, into NaN without any error.
if any(~isfinite(imageFreq) | imageFreq <= 0)
    error('train:classWeights', ...
        'Every class must appear in the training labels to compute median-frequency class weights.');
end
classWeights = median(imageFreq) ./ imageFreq;
% ---- Network --------------------------------------------------------------
% Input frames are 64 x 1024 with 5 channels. createLiPoSeg (project helper)
% builds the segmentation layer graph using the classes and balancing
% weights computed above.
inputSize = [64, 1024, 5];
lgraph = createLiPoSeg(inputSize, classNames, classWeights);
% Open the layer graph in the Deep Learning Network Analyzer for inspection.
analyzeNetwork(lgraph)
% ---- Training configuration ---------------------------------------------
% Optimizer hyperparameters.
initialLearningRate = 5e-4;
l2reg = 2e-4;
miniBatchSize = 8;
maxEpochs = 30;

% RMSProp with a piecewise schedule: the learning rate is multiplied by 0.1
% every 10 epochs. Validation runs every 120 iterations and console progress
% is printed every 60 iterations. Training uses a parallel pool.
options = trainingOptions('rmsprop', ...
    'MaxEpochs',            maxEpochs, ...
    'MiniBatchSize',        miniBatchSize, ...
    'InitialLearnRate',     initialLearningRate, ...
    'LearnRateSchedule',    'piecewise', ...
    'LearnRateDropPeriod',  10, ...
    'LearnRateDropFactor',  0.1, ...
    'L2Regularization',     l2reg, ...
    'ValidationData',       validationData, ...
    'ValidationFrequency',  120, ...
    'VerboseFrequency',     60, ...
    'Plots',                'training-progress', ...
    'ExecutionEnvironment', 'parallel');
% ---- Train or load --------------------------------------------------------
% When doTraining is true, train from scratch. Otherwise load a previously
% trained network. Either way, `net` is defined after this block.
doTraining = true;
if doTraining
    % Train on augmentedTrainingData (trainingData passed through
    % augmentData). Passing the plain trainingData would skip the
    % augmentation entirely.
    [net, info] = trainNetwork(augmentedTrainingData, lgraph, options);
else
    % NOTE(review): assumes trainedPointSegNet.mat is on the path and
    % contains a variable named `net` -- confirm.
    pretrainedNetwork = load('trainedPointSegNet.mat');
    net = pretrainedNetwork.net;
end