diff --git a/Line.go b/Line.go new file mode 100644 index 0000000..2955b2e --- /dev/null +++ b/Line.go @@ -0,0 +1,44 @@ +package main + +import ( + "fmt" + "log" + "math" +) + +type line struct { + k float64 + b float64 +} + +func (l *line) y(x float64) float64 { + return l.k*x + l.b +} + +func NewLine() *line { + return &line{0, 0} +} + +func (l *line) Train(points []point, lr float64, epochs uint) error { + for i := uint(0); i < epochs; i++ { + y := make([]float64, len(points)) + for j := range points { + y[j] = l.y(points[j].x) + } + + var sum1, sum2 float64 + for j := range points { + sum1 += y[j] - points[j].y + sum2 += y[j] - points[j].y*points[j].x + } + + l.k -= lr * (2 / float64(len(points))) * sum2 + l.b -= lr * (2 / float64(len(points))) * sum1 + loss := sum1 / float64(len(points)) + if math.IsNaN(loss) || math.IsInf(loss, 0) { + return fmt.Errorf("training unsuccessful") + } + log.Printf("Epoch: %d/%d, Loss %f\n", i, epochs, loss) + } + return nil +} diff --git a/main.go b/main.go index e39985b..92d5f2b 100644 --- a/main.go +++ b/main.go @@ -3,31 +3,53 @@ package main import ( "image" "log" - "time" "github.com/hajimehoshi/ebiten/v2" ) -func f(x float64) float64 { return x*x + 5*x - 3 } -func df(x float64) float64 { return 2*x + 5 } +const ( + e = 1e-5 +) + +func f(x float64) float64 { return x*x*x*x + x*x + 5*x - 3 } + +func df(x float64) float64 { return 4*x*x*x + 2*x + 5 } + +// func Derivative(x float64, f func(x float64) float64) float64 { +// return (f(x+e) - f(x)) / e +// } func main() { + p := NewRandomPoints(10) + l := NewLine() + err := l.Train(p, 0.001, 5000) + if err != nil { + log.Fatal(err) + } ebiten.SetWindowSize(640, 480) ebiten.SetWindowTitle("Gradient descent") img := make(chan *image.RGBA, 1) go func() { - p := Plot(-5, 0, 0.1, f) - x := 0.0 - img <- p(x) - for i := 0; i < 50; i++ { - time.Sleep(30 * time.Millisecond) - x -= df(x) * 0.1 - img <- p(x) + p := Plot(-30, 30, 0.1, f) + x := 30.0 + img <- p(x, false) + for i := 0; i < 500000; i++ { + // time.Sleep( * time.Millisecond) + x -= df(x) * 1e-6 + img <- p(x, false) } + }() if err := ebiten.RunGame(&App{Img: img}); err != nil { log.Fatal(err) } } + +func Abs(a float64) float64 { + if a < 0 { + return a * -1 + } + return a +} diff --git a/plot.go b/plot.go index ebad8a9..b6eae71 100644 --- a/plot.go +++ b/plot.go @@ -13,7 +13,7 @@ import ( "gonum.org/v1/plot/vg/vgimg" ) -func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64) *image.RGBA { +func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64, isMinimum bool) *image.RGBA { var pts plotter.XYs for x := xmin; x <= xmax; x += xstep { pts = append(pts, plotter.XY{X: x, Y: f(x)}) @@ -23,17 +23,23 @@ func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64) *i log.Fatalf("Failed to NewLine: %v", err) } fn.Color = color.RGBA{B: 255, A: 255} - return func(x float64) *image.RGBA { + return func(x float64, isMinimum bool) *image.RGBA { pts := plotter.XYs{{X: x, Y: f(x)}} xScatter, err := plotter.NewScatter(pts) if err != nil { log.Fatalf("Failed to NewScatter: %v", err) } - xScatter.Color = color.RGBA{R: 255, A: 255} + if isMinimum { + xScatter.Color = color.RGBA{G: 255, A: 255} + + } else { + xScatter.Color = color.RGBA{R: 255, A: 255} + + } labels, err := plotter.NewLabels(plotter.XYLabels{ XYs: pts, - Labels: []string{fmt.Sprintf("x = %.2f", x)}, + Labels: []string{fmt.Sprintf("x = %.5f", x)}, }) labels.Offset = vg.Point{X: -10, Y: 15} if err != nil { diff --git a/point.go b/point.go new file mode 100644 index 0000000..b934459 --- /dev/null +++ b/point.go @@ -0,0 +1,15 @@ +package main + +import "math/rand" + +type point struct { + x, y float64 +} + +func NewRandomPoints(num int) []point { + points := make([]point, num) + for i := 0; i < num; i++ { + points[i] = point{rand.Float64() * 300, rand.Float64() * 300} + } + return points +}