Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 90 additions & 15 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,32 +1,107 @@
package main

import (
"fmt"
"image"
"log"
"time"
"math/rand"

"github.com/hajimehoshi/ebiten/v2"
"gonum.org/v1/plot/plotter"
)

func f(x float64) float64 { return x*x + 5*x - 3 }
func df(x float64) float64 { return 2*x + 5 }
const (
screenWidth, screenHeight = 640, 480
randMin, randMax = -10, 10
epochs, lr = 1000000, 0.0001
plotMinX, plotMaxX, plotMinY, plotMaxY = -10, 10, -50, 100
inputPointsMinY, inputPointsMaxY, inputPoints = -20, 20, 10
)

// Function points are spawed along
func f(x float64) float64 {
return x*x + 5*x - 3
}

// func df(f func(float64) float64) func(float64) float64 {
// return func(x float64) float64 {
// dx := 1e-10 // dx -> 0
// return (f(x+dx) - f(x)) / dx
// }
// }

// Inference for 1 argument
func i(x, w, b float64) float64 { return w*x + b }

// For all the input data
func inference(x []float64, w, b float64) (out []float64) {
for _, v := range x {
out = append(out, i(v, w, b))
}
return
}

// func loss(y, labels []float64) float64 {
// var errSum float64
// for i := range labels {
// errSum += math.Pow((y[i] - labels[i]), 2)
// }
// return errSum / float64(len(labels)) // n
// }

func gradient(xs, ys, labels []float64) (w, b float64) {
for i := 0; i < len(labels); i++ {
dif := ys[i] - labels[i]
w += dif * xs[i]
b += dif
}
n := float64(len(labels))
w *= 2 / n
b *= 2 / n
return
}

func train(inputs, labels []float64) (w, b float64) {
randFloat64 := func() float64 {
return randMin + rand.Float64()*(randMax-randMin)
}
w, b = randFloat64(), randFloat64()
var dw, db float64
for i := 0; i < epochs; i++ {
dw, db = gradient(labels, inference(inputs, w, b), inputs)
w -= dw * lr
b -= db * lr
fmt.Println(w, b)
}
return
}

func randPoints(f func(float64) float64, inputPointsMinY, inputPointsMaxY float64, inputPoints int) (xs, labels []float64) {
for i := 0; i < inputPoints; i++ {
x := plotMinX + (plotMaxX-plotMinX)*rand.Float64()
inputPointsY := inputPointsMinY + (inputPointsMaxY-inputPointsMinY)*rand.Float64()
xs = append(xs, x)
labels = append(labels, f(x)+inputPointsY)
}
return
}

func main() {
ebiten.SetWindowSize(640, 480)
ebiten.SetWindowSize(screenWidth, screenHeight)
ebiten.SetWindowTitle("Gradient descent")

img := make(chan *image.RGBA, 1)
go func() {
p := Plot(-5, 0, 0.1, f)
x := 0.0
img <- p(x)
for i := 0; i < 50; i++ {
time.Sleep(30 * time.Millisecond)
x -= df(x) * 0.1
img <- p(x)
}
}()

inputs, labels := randPoints(f, inputPointsMinY, inputPointsMaxY, inputPoints)
var points plotter.XYs
for i := 0; i < inputPoints; i++ {
points = append(points, plotter.XY{X: inputs[i], Y: labels[i]})
}
pointsScatter, _ := plotter.NewScatter(points)
fp := plotter.NewFunction(f)
w, b := train(inputs, labels)
fmt.Println(w, b)
ap := plotter.NewFunction(func(x float64) float64 { return w*x + b })
img <- Plot(pointsScatter, fp, ap)
if err := ebiten.RunGame(&App{Img: img}); err != nil {
log.Fatal(err)
}
Expand Down
59 changes: 11 additions & 48 deletions plot.go
Original file line number Diff line number Diff line change
@@ -1,60 +1,23 @@
package main

import (
"fmt"
"image"
"image/color"
"log"

"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/vg"
"gonum.org/v1/plot/vg/draw"
"gonum.org/v1/plot/vg/vgimg"
)

func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64) *image.RGBA {
var pts plotter.XYs
for x := xmin; x <= xmax; x += xstep {
pts = append(pts, plotter.XY{X: x, Y: f(x)})
}
fn, err := plotter.NewLine(pts)
if err != nil {
log.Fatalf("Failed to NewLine: %v", err)
}
fn.Color = color.RGBA{B: 255, A: 255}
return func(x float64) *image.RGBA {
pts := plotter.XYs{{X: x, Y: f(x)}}
xScatter, err := plotter.NewScatter(pts)
if err != nil {
log.Fatalf("Failed to NewScatter: %v", err)
}
xScatter.Color = color.RGBA{R: 255, A: 255}

labels, err := plotter.NewLabels(plotter.XYLabels{
XYs: pts,
Labels: []string{fmt.Sprintf("x = %.2f", x)},
})
labels.Offset = vg.Point{X: -10, Y: 15}
if err != nil {
log.Fatalf("Failed to NewLabels: %v", err)
}

p := plot.New()
p.Add(
plotter.NewGrid(),
fn,
xScatter,
labels,
)
p.Legend.Add("f(x)", fn)
p.Legend.Add("x", xScatter)
p.X.Label.Text = "X"
p.Y.Label.Text = "Y"

img := image.NewRGBA(image.Rect(0, 0, 640, 480))
c := vgimg.NewWith(vgimg.UseImage(img))
p.Draw(draw.New(c))
return c.Image().(*image.RGBA)
}
func Plot(ps ...plot.Plotter) *image.RGBA {
p := plot.New()
p.X.Min, p.X.Max = plotMinX, plotMaxX
p.Y.Min, p.Y.Max = plotMinY, plotMaxY
p.Add(append([]plot.Plotter{
plotter.NewGrid(),
}, ps...)...)
img := image.NewRGBA(image.Rect(0, 0, screenWidth, screenHeight))
c := vgimg.NewWith(vgimg.UseImage(img))
p.Draw(draw.New(c))
return c.Image().(*image.RGBA)
}