Skip to content

Commit

Permalink
Load/Save net
Browse files Browse the repository at this point in the history
  • Loading branch information
radioman committed Nov 24, 2016
1 parent 39fc715 commit c0f3ed1
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 47 deletions.
104 changes: 90 additions & 14 deletions Examples/MnistDemo/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ internal class Program
private const string trainingImageFile = "train-images-idx3-ubyte.gz";
private const string testingLabelFile = "t10k-labels-idx1-ubyte.gz";
private const string testingImageFile = "t10k-images-idx3-ubyte.gz";

private void MnistDemo()
{
Directory.CreateDirectory(mnistFolder);
Expand All @@ -44,6 +44,7 @@ private void MnistDemo()
Console.WriteLine("Downloading Mnist training files...");
DownloadFile(urlMnist + trainingLabelFile, trainingLabelFilePath);
DownloadFile(urlMnist + trainingImageFile, trainingImageFilePath);

Console.WriteLine("Downloading Mnist testing files...");
DownloadFile(urlMnist + testingLabelFile, testingLabelFilePath);
DownloadFile(urlMnist + testingImageFile, testingImageFilePath);
Expand All @@ -60,14 +61,36 @@ private void MnistDemo()
return;
}

// Create network
this.net = new Net();
this.net.AddLayer(new InputLayer(24, 24, 1));
this.net.AddLayer(new ConvLayer(5, 5, 8) { Stride = 1, Pad = 2, Activation = Activation.Relu });
this.net.AddLayer(new PoolLayer(2, 2) { Stride = 2 });
this.net.AddLayer(new ConvLayer(5, 5, 16) { Stride = 1, Pad = 2, Activation = Activation.Relu });
this.net.AddLayer(new PoolLayer(3, 3) { Stride = 3 });
this.net.AddLayer(new SoftmaxLayer(10));
//ExtractImages();

Console.WriteLine($"load net?");
if (Console.ReadKey(true).Key == ConsoleKey.Enter)
{
var f = Path.Combine(mnistFolder, "net.bin");
Console.WriteLine($"loading: {f}");

this.net = Net.Load(f);
}
else
{
this.net = new Net();
this.net.AddLayer(new InputLayer(24, 24, 1));
this.net.AddLayer(new ConvLayer(12, 12, 8)
{
Stride = 1,
Pad = 2,
Activation = Activation.Relu
});
//this.net.AddLayer(new PoolLayer(2, 2) { Stride = 2 });
this.net.AddLayer(new ConvLayer(6, 6, 16)
{
Stride = 1,
Pad = 2,
Activation = Activation.Relu
});
//this.net.AddLayer(new PoolLayer(3, 3) { Stride = 3 });
this.net.AddLayer(new SoftmaxLayer(10));
}

this.trainer = new AdadeltaTrainer(this.net)
{
Expand All @@ -80,7 +103,49 @@ private void MnistDemo()
{
var sample = this.SampleTrainingInstance();
this.Step(sample);
} while (!Console.KeyAvailable);
}
while (!Console.KeyAvailable);
Console.ReadKey(true);

Console.WriteLine($"save net?");
if (Console.ReadKey(true).Key == ConsoleKey.Enter)
{
var f = Path.Combine(mnistFolder, "net.bin");
Console.WriteLine($"saving: {f}");

net.Save(f);
}
Console.WriteLine("done.");
Console.ReadKey();
}

private void ExtractImages()
{
int x = 0;
foreach (var t in training)
{
var dir = Path.Combine(mnistFolder, Path.GetFileNameWithoutExtension(trainingImageFile), $"{t.Label}");
if (!Directory.Exists(dir))
{
Directory.CreateDirectory(dir);
}
var f = Path.Combine(dir, $"{x++}.raw");

File.WriteAllBytes(f, t.Image);
}

x = 0;
foreach (var t in testing)
{
var dir = Path.Combine(mnistFolder, Path.GetFileNameWithoutExtension(testingImageFile), $"{t.Label}");
if (!Directory.Exists(dir))
{
Directory.CreateDirectory(dir);
}
var f = Path.Combine(dir, $"{x++}.raw");

File.WriteAllBytes(f, t.Image);
}
}

private void DownloadFile(string urlFile, string destFilepath)
Expand Down Expand Up @@ -129,7 +194,7 @@ private void Step(Item sample)
this.wLossWindow.Add(lossw);
this.trainAccWindow.Add(trainAcc);

if (this.stepCount % 200 == 0)
if (this.stepCount % 10 == 0)
{
if (this.xLossWindow.Count == this.xLossWindow.Capacity)
{
Expand All @@ -144,6 +209,8 @@ private void Step(Item sample)
Console.WriteLine("Example seen: {0} Fwd: {1}ms Bckw: {2}ms", this.stepCount,
Math.Round(this.trainer.ForwardTime.TotalMilliseconds, 2),
Math.Round(this.trainer.BackwardTime.TotalMilliseconds, 2));

Console.WriteLine();
}
}

Expand Down Expand Up @@ -235,11 +302,20 @@ private static void Main(string[] args)

private class Item
{
public Volume Volume { get; set; }
public Volume Volume
{
get; set;
}

public int Label { get; set; }
public int Label
{
get; set;
}

public bool IsValidation { get; set; }
public bool IsValidation
{
get; set;
}
}
}
}
99 changes: 66 additions & 33 deletions src/ConvNetSharp/Net.cs
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.Serialization;
using System.Xml;
using ConvNetSharp.Layers;

namespace ConvNetSharp
{
[DataContract]
public class Net
{
private readonly List<LayerBase> layers = new List<LayerBase>();

public List<LayerBase> Layers
{
get { return this.layers; }
}
[DataMember]
public readonly List<LayerBase> Layers = new List<LayerBase>();

public void AddLayer(LayerBase layer)
{
int inputWidth = 0, inputHeight = 0, inputDepth = 0;
if (this.layers.Count > 0)
if (this.Layers.Count > 0)
{
inputWidth = this.layers[this.layers.Count - 1].OutputWidth;
inputHeight = this.layers[this.layers.Count - 1].OutputHeight;
inputDepth = this.layers[this.layers.Count - 1].OutputDepth;
inputWidth = this.Layers[this.Layers.Count - 1].OutputWidth;
inputHeight = this.Layers[this.Layers.Count - 1].OutputHeight;
inputDepth = this.Layers[this.Layers.Count - 1].OutputDepth;
}

var classificationLayer = layer as IClassificationLayer;
Expand All @@ -32,7 +32,7 @@ public void AddLayer(LayerBase layer)
inputHeight = fullyConnLayer.OutputHeight;
inputDepth = fullyConnLayer.OutputDepth;

this.layers.Add(fullyConnLayer);
this.Layers.Add(fullyConnLayer);
}

var regressionLayer = layer as RegressionLayer;
Expand All @@ -44,7 +44,7 @@ public void AddLayer(LayerBase layer)
inputHeight = fullyConnLayer.OutputHeight;
inputDepth = fullyConnLayer.OutputDepth;

this.layers.Add(fullyConnLayer);
this.Layers.Add(fullyConnLayer);
}

var dotProductLayer = layer as IDotProductLayer;
Expand All @@ -58,12 +58,12 @@ public void AddLayer(LayerBase layer)
}
}

if (this.layers.Count > 0)
if (this.Layers.Count > 0)
{
layer.Init(inputWidth, inputHeight, inputDepth);
}

this.layers.Add(layer);
this.Layers.Add(layer);

if (dotProductLayer != null)
{
Expand All @@ -74,45 +74,45 @@ public void AddLayer(LayerBase layer)
case Activation.Relu:
var reluLayer = new ReluLayer();
reluLayer.Init(layer.OutputWidth, layer.OutputHeight, layer.OutputDepth);
this.layers.Add(reluLayer);
this.Layers.Add(reluLayer);
break;
case Activation.Sigmoid:
var sigmoidLayer = new SigmoidLayer();
sigmoidLayer.Init(layer.OutputWidth, layer.OutputHeight, layer.OutputDepth);
this.layers.Add(sigmoidLayer);
this.Layers.Add(sigmoidLayer);
break;
case Activation.Tanh:
var tanhLayer = new TanhLayer();
tanhLayer.Init(layer.OutputWidth, layer.OutputHeight, layer.OutputDepth);
this.layers.Add(tanhLayer);
this.Layers.Add(tanhLayer);
break;
case Activation.Maxout:
var maxoutLayer = new MaxoutLayer { GroupSize = dotProductLayer.GroupSize };
maxoutLayer.Init(layer.OutputWidth, layer.OutputHeight, layer.OutputDepth);
this.layers.Add(maxoutLayer);
this.Layers.Add(maxoutLayer);
break;
default:
throw new ArgumentOutOfRangeException();
}
}

var lastLayer = this.layers[this.layers.Count - 1];
var lastLayer = this.Layers[this.Layers.Count - 1];

if (!(layer is DropOutLayer) && layer.DropProb.HasValue)
{
var dropOutLayer = new DropOutLayer(layer.DropProb.Value);
dropOutLayer.Init(lastLayer.OutputWidth, lastLayer.OutputHeight, lastLayer.OutputDepth);
this.layers.Add(dropOutLayer);
this.Layers.Add(dropOutLayer);
}
}

public Volume Forward(Volume volume, bool isTraining = false)
{
var activation = this.layers[0].Forward(volume, isTraining);
var activation = this.Layers[0].Forward(volume, isTraining);

for (var i = 1; i < this.layers.Count; i++)
for (var i = 1; i < this.Layers.Count; i++)
{
var layerBase = this.layers[i];
var layerBase = this.Layers[i];
activation = layerBase.Forward(activation, isTraining);
}

Expand All @@ -123,7 +123,7 @@ public double GetCostLoss(Volume volume, double y)
{
this.Forward(volume);

var lastLayer = this.layers[this.layers.Count - 1] as ILastLayer;
var lastLayer = this.Layers[this.Layers.Count - 1] as ILastLayer;
if (lastLayer != null)
{
var loss = lastLayer.Backward(y);
Expand All @@ -137,7 +137,7 @@ public double GetCostLoss(Volume volume, double[] y)
{
this.Forward(volume);

var lastLayer = this.layers[this.layers.Count - 1] as ILastLayer;
var lastLayer = this.Layers[this.Layers.Count - 1] as ILastLayer;
if (lastLayer != null)
{
var loss = lastLayer.Backward(y);
Expand All @@ -149,15 +149,15 @@ public double GetCostLoss(Volume volume, double[] y)

public double Backward(double y)
{
var n = this.layers.Count;
var lastLayer = this.layers[n - 1] as ILastLayer;
var n = this.Layers.Count;
var lastLayer = this.Layers[n - 1] as ILastLayer;
if (lastLayer != null)
{
var loss = lastLayer.Backward(y); // last layer assumed to be loss layer
for (var i = n - 2; i >= 0; i--)
{
// first layer assumed input
this.layers[i].Backward();
this.Layers[i].Backward();
}
return loss;
}
Expand All @@ -167,15 +167,15 @@ public double Backward(double y)

public double Backward(double[] y)
{
var n = this.layers.Count;
var lastLayer = this.layers[n - 1] as ILastLayer;
var n = this.Layers.Count;
var lastLayer = this.Layers[n - 1] as ILastLayer;
if (lastLayer != null)
{
var loss = lastLayer.Backward(y); // last layer assumed to be loss layer
for (var i = n - 2; i >= 0; i--)
{
// first layer assumed input
this.layers[i].Backward();
this.Layers[i].Backward();
}
return loss;
}
Expand All @@ -187,7 +187,7 @@ public int GetPrediction()
{
// this is a convenience function for returning the argmax
// prediction, assuming the last layer of the net is a softmax
var softmaxLayer = this.layers[this.layers.Count - 1] as SoftmaxLayer;
var softmaxLayer = this.Layers[this.Layers.Count - 1] as SoftmaxLayer;
if (softmaxLayer == null)
{
throw new Exception("GetPrediction function assumes softmax as last layer of the net!");
Expand All @@ -213,13 +213,46 @@ public List<ParametersAndGradients> GetParametersAndGradients()
{
var response = new List<ParametersAndGradients>();

foreach (LayerBase t in this.layers)
foreach (LayerBase t in Layers)
{
List<ParametersAndGradients> parametersAndGradients = t.GetParametersAndGradients();
response.AddRange(parametersAndGradients);
}

return response;
}

public void Save(string fileName)
{
using (FileStream fs = new FileStream(fileName, FileMode.Create))
{
using (XmlDictionaryWriter bdw = XmlDictionaryWriter.CreateBinaryWriter(fs))
{
var ser = new DataContractSerializer(typeof(Net));
ser.WriteObject(bdw, this);
bdw.Flush();
}
}
}

public static Net Load(string fileName)
{
Net ret;
using (FileStream fs = new FileStream(fileName, FileMode.Open))
{
var q = new XmlDictionaryReaderQuotas()
{
MaxArrayLength = 1024 * 1024 * 10, // 10MB
MaxBytesPerRead = 1024 * 1024 * 10 // 10MB
};

using (var reader = XmlDictionaryReader.CreateBinaryReader(fs, q))
{
var ser = new DataContractSerializer(typeof(Net));
ret = (Net)ser.ReadObject(reader, true);
}
}
return ret;
}
}
}

0 comments on commit c0f3ed1

Please sign in to comment.