Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
249 lines (228 sloc) 8.81 KB
using System;
using System.Diagnostics;
using System.Linq;
using System.Runtime.InteropServices;
using NUnit.Framework;
namespace MiscCodeTests
{
/// <summary>
/// Trains a tiny feed-forward network (3 inputs, 4 hidden units, 1 output, sigmoid
/// activations) on a four-row truth table whose target is the XOR of the first two
/// input columns; the third input column is constantly 1.
/// </summary>
[TestFixture]
public class SimpleNeuralNet
{
    // translation of Siraj python code from video https://www.youtube.com/watch?v=h3l4qz76JhQ into C#
    // https://github.com/stmorgan/pythonNNexample/blob/master/PythonNNExampleFromSirajology.py
    // if you want to rewrite this as command line, replace Run with the static Main method

    /// <summary>Number of training passes over the whole input table.</summary>
    private const int TrainingIterations = 60000;

    /// <summary>The current mean error is printed once every this many iterations.</summary>
    private const int ReportInterval = 10000;

    /// <summary>
    /// Upper bound on the mean absolute error accepted after training.
    /// NOTE(review): chosen as a generous bound for the hard-coded start weights;
    /// confirm against a local run and loosen it if you switch to random weights.
    /// </summary>
    private const double MaxAcceptableError = 0.01;

    /// <summary>
    /// Trains the network, prints the mean absolute error at regular intervals and the
    /// elapsed training time, then fails if the trained network has not converged.
    /// </summary>
    [Test]
    public void Run()
    {
        var input = new double[,]
        {
            {0, 0, 1},
            {0, 1, 1},
            {1, 0, 1},
            {1, 1, 1},
        };
        var output = new double[,]
        {
            {0}, {1}, {1}, {0}
        };

        // uncomment this and comment the code below to get random variables
        /*var rand = new Random(1);
        var syn0 = new double[3,4].PopulateWithRandom(rand);
        var syn1 = new double[4,1].PopulateWithRandom(rand);*/

        // this is hardcoded to match the siraj example generated by python
        // input to hidden synapses
        var syn0 = new [,]
        {
            {-0.16595599,  0.44064899, -0.99977125, -0.39533485},
            {-0.70648822, -0.81532281, -0.62747958, -0.30887855},
            {-0.20646505,  0.07763347, -0.16161097,  0.370439  }
        };
        // hidden to output synapses
        var syn1 = new [,]
        {
            {-0.5910955 },
            { 0.75623487},
            {-0.94522481},
            { 0.34093502}
        };

        var sw = Stopwatch.StartNew();
        for (var iteration = 0; iteration < TrainingIterations; iteration++)
        {
            // Forward pass: input -> hidden -> output.
            var l0 = input;
            var l1 = l0.MultiplyByMatrix(syn0).Apply(Func);
            var l2 = l1.MultiplyByMatrix(syn1).Apply(Func);

            var errorsL2 = output.SubstractValues(l2);
            if (iteration % ReportInterval == 0)
            {
                Console.WriteLine("Error:" + MeanAbsoluteError(errorsL2));
            }

            // Backward pass: scale each layer's error by the activation slope at that
            // layer's output, then push the output delta back through syn1.
            var deltaL2 = errorsL2.MultiplyValues(l2.Apply(FuncDerivative));
            var errorL1 = deltaL2.MultiplyByMatrix(syn1.Transpose());
            var deltaL1 = errorL1.MultiplyValues(l1.Apply(FuncDerivative));

            // Weight update; syn1 is updated only after errorL1 was computed from the old syn1.
            syn1 = syn1.AddValues(l1.Transpose().MultiplyByMatrix(deltaL2));
            syn0 = syn0.AddValues(l0.Transpose().MultiplyByMatrix(deltaL1));
        }
        sw.Stop();
        Console.WriteLine(sw.Elapsed);

        // One extra forward pass (outside the timed section) with the trained weights,
        // so the test actually fails when training or a matrix helper is broken.
        var prediction = input.MultiplyByMatrix(syn0).Apply(Func)
                              .MultiplyByMatrix(syn1).Apply(Func);
        var finalError = MeanAbsoluteError(output.SubstractValues(prediction));
        Assert.That(finalError, Is.LessThan(MaxAcceptableError),
            "Network did not converge: mean absolute error after training is too high.");
    }

    /// <summary>Average of the absolute values of all elements of <paramref name="errors"/>.</summary>
    private static double MeanAbsoluteError(double[,] errors)
    {
        var total = errors.AggregateEach(0.0, (sum, e) => sum + Math.Abs(e));
        return total / (errors.Columns() * errors.Rows());
    }

    /// <summary>Activation function: the logistic sigmoid 1 / (1 + e^-x).</summary>
    private static double Func(double x)
    {
        return 1.0 / (1.0 + Math.Exp(-x));
    }

    /// <summary>
    /// Slope of the sigmoid expressed in terms of the sigmoid's OUTPUT: callers pass an
    /// already-activated value (l1 / l2 in <see cref="Run"/>), not the pre-activation input.
    /// </summary>
    private static double FuncDerivative(double x)
    {
        // Alternative noted by the original author: Math.Sqrt(1 - Math.Pow(x * 2 - 1, 2))
        // makes the error disappear faster.
        return x * (1.0 - x);
    }
}
/// <summary>
/// Matrix helpers for rectangular <c>double[,]</c> arrays used by <c>SimpleNeuralNet</c>.
/// Except for <see cref="PopulateWithRandom"/> (which fills its argument in place) and
/// <see cref="PrintInConsole"/>, every method returns a newly allocated result and
/// leaves its arguments unmodified.
/// </summary>
internal static class SimpleNeuralNetExtensions
{
    /// <summary>
    /// Folds every element of <paramref name="a"/> into an accumulator, visiting the
    /// elements row by row.
    /// </summary>
    /// <param name="a">Matrix to walk.</param>
    /// <param name="aggregator">Initial accumulator value.</param>
    /// <param name="func">Combines the running accumulator with the next element.</param>
    /// <returns>The accumulator after the last element.</returns>
    internal static T AggregateEach<T>(this double[,] a, T aggregator, Func<T,double,T> func)
    {
        var rows = a.GetLength(0);
        var cols = a.GetLength(1);
        for (var r = 0; r < rows; r++)
        {
            for (var c = 0; c < cols; c++)
            {
                aggregator = func(aggregator, a[r, c]);
            }
        }
        return aggregator;
    }

    /// <summary>
    /// Element-wise difference <c>a - b</c>. (The misspelled name is kept so existing
    /// callers keep compiling.)
    /// </summary>
    /// <exception cref="ArgumentException">The two matrices differ in shape.</exception>
    internal static double[,] SubstractValues(this double[,] a, double[,] b)
    {
        RequireSameShape(a, b);
        var rows = a.GetLength(0);
        var cols = a.GetLength(1);
        var result = new double[rows, cols];
        for (var r = 0; r < rows; r++)
        {
            for (var c = 0; c < cols; c++)
            {
                result[r, c] = a[r, c] - b[r, c];
            }
        }
        return result;
    }

    /// <summary>Element-wise sum <c>a + b</c>.</summary>
    /// <exception cref="ArgumentException">The two matrices differ in shape.</exception>
    internal static double[,] AddValues(this double[,] a, double[,] b)
    {
        RequireSameShape(a, b);
        var rows = a.GetLength(0);
        var cols = a.GetLength(1);
        var result = new double[rows, cols];
        for (var r = 0; r < rows; r++)
        {
            for (var c = 0; c < cols; c++)
            {
                result[r, c] = a[r, c] + b[r, c];
            }
        }
        return result;
    }

    /// <summary>
    /// Element-wise product of two equally shaped matrices. Not to be confused with
    /// matrix multiplication (see <see cref="MultiplyByMatrix"/>).
    /// </summary>
    /// <exception cref="ArgumentException">The two matrices differ in shape.</exception>
    internal static double[,] MultiplyValues(this double[,] a, double[,] b)
    {
        RequireSameShape(a, b);
        var rows = a.Rows();
        var cols = a.Columns();
        // The result takes a's shape: the loop below writes exactly rows x cols cells.
        var result = new double[rows, cols];
        for (var r = 0; r < rows; r++)
        {
            for (var c = 0; c < cols; c++)
            {
                result[r, c] = a[r, c] * b[r, c];
            }
        }
        return result;
    }

    /// <summary>
    /// Overwrites every element of <paramref name="matrix"/> with
    /// <c>rand.NextDouble() * 2.0 - 1.0</c> and returns the same array instance.
    /// </summary>
    internal static double[,] PopulateWithRandom(this double[,] matrix, Random rand)
    {
        var rows = matrix.GetLength(0);
        var cols = matrix.GetLength(1);
        for (var r = 0; r < rows; r++)
        {
            for (var c = 0; c < cols; c++)
            {
                matrix[r, c] = rand.NextDouble() * 2.0 - 1.0;
            }
        }
        return matrix;
    }

    /// <summary>
    /// Matrix product <c>a * b</c>.
    /// </summary>
    /// <param name="a">Left operand, N x K.</param>
    /// <param name="b">Right operand, K x M.</param>
    /// <returns>A new N x M matrix.</returns>
    /// <exception cref="ArgumentException">
    /// The column count of <paramref name="a"/> differs from the row count of
    /// <paramref name="b"/>.
    /// </exception>
    internal static double[,] MultiplyByMatrix(this double[,] a, double[,] b)
    {
        // The pointer loops below trust these dimensions completely; without this
        // check a mismatch would read memory beyond the end of b.
        if (a.Columns() != b.Rows())
        {
            throw new ArgumentException(
                "Cannot multiply a " + a.Rows() + "x" + a.Columns()
                + " matrix by a " + b.Rows() + "x" + b.Columns()
                + " matrix: the inner dimensions differ.", "b");
        }

        var rows = a.Rows();
        var inner = a.Columns();
        var cols = b.Columns();
        var result = new double[rows, cols];
        var column = new double[inner]; // scratch copy of one column of b

        // If the project cannot be compiled with unsafe code, the block below is
        // equivalent to the plain triple loop
        //     result[i, j] = sum over k of a[i, k] * b[k, j].
        unsafe
        {
            fixed (double* pA = a)
            fixed (double* pB = b)
            fixed (double* pColumn = column)
            fixed (double* pResult = result)
            {
                for (int j = 0; j < cols; j++)
                {
                    // Gather column j of b so the dot products below read it contiguously.
                    double* source = pB + j;
                    for (int k = 0; k < inner; k++)
                    {
                        pColumn[k] = *source;
                        source += cols;
                    }

                    double* rowOfA = pA;
                    double* target = pResult + j;
                    for (int i = 0; i < rows; i++)
                    {
                        double sum = 0.0;
                        for (int k = 0; k < inner; k++)
                        {
                            sum += rowOfA[k] * pColumn[k];
                        }
                        *target = sum;
                        rowOfA += inner; // next row of a
                        target += cols;  // same column, next row of result
                    }
                }
            }
        }
        return result;
    }

    /// <summary>Length of the first dimension of <paramref name="matrix"/>.</summary>
    internal static int Rows<T>(this T[,] matrix)
    {
        return matrix.GetLength(0);
    }

    /// <summary>Length of the second dimension of <paramref name="matrix"/>.</summary>
    internal static int Columns<T>(this T[,] matrix)
    {
        return matrix.GetLength(1);
    }

    /// <summary>
    /// Returns a new matrix of the same shape whose elements are
    /// <paramref name="func"/> applied to the corresponding elements of
    /// <paramref name="matrix"/>.
    /// </summary>
    internal static double[,] Apply(this double[,] matrix, Func<double,double> func)
    {
        var rows = matrix.Rows();
        var cols = matrix.Columns();
        var result = new double[rows, cols];
        for (var r = 0; r < rows; r++)
        {
            for (var c = 0; c < cols; c++)
            {
                result[r, c] = func(matrix[r, c]);
            }
        }
        return result;
    }

    /// <summary>Returns a new matrix with rows and columns swapped.</summary>
    internal static T[,] Transpose<T>(this T[,] matrix)
    {
        int rows = matrix.GetLength(0);
        int cols = matrix.GetLength(1);
        T[,] result = new T[cols, rows];
        for (int r = 0; r < rows; r++)
        {
            for (int c = 0; c < cols; c++)
            {
                result[c, r] = matrix[r, c];
            }
        }
        return result;
    }

    /// <summary>
    /// Debugging aid: writes the matrix to the console, one tab-indented,
    /// brace-enclosed, comma-separated row per line.
    /// </summary>
    internal static void PrintInConsole(this double[,] arr)
    {
        int rowLength = arr.GetLength(0);
        int colLength = arr.GetLength(1);
        Console.WriteLine("{");
        for (int i = 0; i < rowLength; i++)
        {
            Console.Write("\t{");
            for (int j = 0; j < colLength; j++)
            {
                if (j > 0)
                {
                    Console.Write(",");
                }
                Console.Write("{0}", arr[i, j]);
            }
            // Every row except the last is followed by a separating comma.
            if (i < (rowLength - 1))
            {
                Console.WriteLine("},");
            }
            else
            {
                Console.WriteLine("}");
            }
        }
        Console.WriteLine("}");
    }

    /// <summary>Throws when two matrices cannot be combined element by element.</summary>
    private static void RequireSameShape(double[,] a, double[,] b)
    {
        if (a.GetLength(0) != b.GetLength(0) || a.GetLength(1) != b.GetLength(1))
        {
            throw new ArgumentException(
                "Element-wise operations need matrices of identical shape, but got "
                + a.GetLength(0) + "x" + a.GetLength(1) + " and "
                + b.GetLength(0) + "x" + b.GetLength(1) + ".", "b");
        }
    }
}
}