diff --git a/main.go b/main.go index e39985b..85d69d2 100644 --- a/main.go +++ b/main.go @@ -1,32 +1,107 @@ package main import ( + "fmt" "image" "log" - "time" + "math/rand" "github.com/hajimehoshi/ebiten/v2" + "gonum.org/v1/plot/plotter" ) -func f(x float64) float64 { return x*x + 5*x - 3 } -func df(x float64) float64 { return 2*x + 5 } +const ( + screenWidth, screenHeight = 640, 480 + randMin, randMax = -10, 10 + epochs, lr = 1000000, 0.0001 + plotMinX, plotMaxX, plotMinY, plotMaxY = -10, 10, -50, 100 + inputPointsMinY, inputPointsMaxY, inputPoints = -20, 20, 10 +) + +// Function points are spawed along +func f(x float64) float64 { + return x*x + 5*x - 3 +} + +// func df(f func(float64) float64) func(float64) float64 { +// return func(x float64) float64 { +// dx := 1e-10 // dx -> 0 +// return (f(x+dx) - f(x)) / dx +// } +// } + +// Inference for 1 argument +func i(x, w, b float64) float64 { return w*x + b } + +// For all the input data +func inference(x []float64, w, b float64) (out []float64) { + for _, v := range x { + out = append(out, i(v, w, b)) + } + return +} + +// func loss(y, labels []float64) float64 { +// var errSum float64 +// for i := range labels { +// errSum += math.Pow((y[i] - labels[i]), 2) +// } +// return errSum / float64(len(labels)) // n +// } + +func gradient(xs, ys, labels []float64) (w, b float64) { + for i := 0; i < len(labels); i++ { + dif := ys[i] - labels[i] + w += dif * xs[i] + b += dif + } + n := float64(len(labels)) + w *= 2 / n + b *= 2 / n + return +} + +func train(inputs, labels []float64) (w, b float64) { + randFloat64 := func() float64 { + return randMin + rand.Float64()*(randMax-randMin) + } + w, b = randFloat64(), randFloat64() + var dw, db float64 + for i := 0; i < epochs; i++ { + dw, db = gradient(labels, inference(inputs, w, b), inputs) + w -= dw * lr + b -= db * lr + fmt.Println(w, b) + } + return +} + +func randPoints(f func(float64) float64, inputPointsMinY, inputPointsMaxY float64, inputPoints int) (xs, labels []float64) { + for i := 0; i < inputPoints; i++ { + x := plotMinX + (plotMaxX-plotMinX)*rand.Float64() + inputPointsY := inputPointsMinY + (inputPointsMaxY-inputPointsMinY)*rand.Float64() + xs = append(xs, x) + labels = append(labels, f(x)+inputPointsY) + } + return +} func main() { - ebiten.SetWindowSize(640, 480) + ebiten.SetWindowSize(screenWidth, screenHeight) ebiten.SetWindowTitle("Gradient descent") img := make(chan *image.RGBA, 1) - go func() { - p := Plot(-5, 0, 0.1, f) - x := 0.0 - img <- p(x) - for i := 0; i < 50; i++ { - time.Sleep(30 * time.Millisecond) - x -= df(x) * 0.1 - img <- p(x) - } - }() - + inputs, labels := randPoints(f, inputPointsMinY, inputPointsMaxY, inputPoints) + var points plotter.XYs + for i := 0; i < inputPoints; i++ { + points = append(points, plotter.XY{X: inputs[i], Y: labels[i]}) + } + pointsScatter, _ := plotter.NewScatter(points) + fp := plotter.NewFunction(f) + w, b := train(inputs, labels) + fmt.Println(w, b) + ap := plotter.NewFunction(func(x float64) float64 { return w*x + b }) + img <- Plot(pointsScatter, fp, ap) if err := ebiten.RunGame(&App{Img: img}); err != nil { log.Fatal(err) } diff --git a/plot.go b/plot.go index ebad8a9..7c612fa 100644 --- a/plot.go +++ b/plot.go @@ -1,60 +1,23 @@ package main import ( - "fmt" "image" - "image/color" - "log" "gonum.org/v1/plot" "gonum.org/v1/plot/plotter" - "gonum.org/v1/plot/vg" "gonum.org/v1/plot/vg/draw" "gonum.org/v1/plot/vg/vgimg" ) -func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64) *image.RGBA { - var pts plotter.XYs - for x := xmin; x <= xmax; x += xstep { - pts = append(pts, plotter.XY{X: x, Y: f(x)}) - } - fn, err := plotter.NewLine(pts) - if err != nil { - log.Fatalf("Failed to NewLine: %v", err) - } - fn.Color = color.RGBA{B: 255, A: 255} - return func(x float64) *image.RGBA { - pts := plotter.XYs{{X: x, Y: f(x)}} - xScatter, err := plotter.NewScatter(pts) - if err != nil { - log.Fatalf("Failed to NewScatter: %v", err) - } - xScatter.Color = color.RGBA{R: 255, A: 255} - - labels, err := plotter.NewLabels(plotter.XYLabels{ - XYs: pts, - Labels: []string{fmt.Sprintf("x = %.2f", x)}, - }) - labels.Offset = vg.Point{X: -10, Y: 15} - if err != nil { - log.Fatalf("Failed to NewLabels: %v", err) - } - - p := plot.New() - p.Add( - plotter.NewGrid(), - fn, - xScatter, - labels, - ) - p.Legend.Add("f(x)", fn) - p.Legend.Add("x", xScatter) - p.X.Label.Text = "X" - p.Y.Label.Text = "Y" - - img := image.NewRGBA(image.Rect(0, 0, 640, 480)) - c := vgimg.NewWith(vgimg.UseImage(img)) - p.Draw(draw.New(c)) - return c.Image().(*image.RGBA) - } +func Plot(ps ...plot.Plotter) *image.RGBA { + p := plot.New() + p.X.Min, p.X.Max = plotMinX, plotMaxX + p.Y.Min, p.Y.Max = plotMinY, plotMaxY + p.Add(append([]plot.Plotter{ + plotter.NewGrid(), + }, ps...)...) + img := image.NewRGBA(image.Rect(0, 0, screenWidth, screenHeight)) + c := vgimg.NewWith(vgimg.UseImage(img)) + p.Draw(draw.New(c)) + return c.Image().(*image.RGBA) }