From 8b04f810090aa40002e4a2240e6bea7ea3a9c417 Mon Sep 17 00:00:00 2001 From: D0LeD Date: Wed, 17 Jan 2024 21:53:19 +0200 Subject: [PATCH] without drawing --- main.go | 82 +++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/main.go b/main.go index e39985b..caf39e0 100644 --- a/main.go +++ b/main.go @@ -1,33 +1,69 @@ package main -import ( - "image" - "log" - "time" +import "fmt" - "github.com/hajimehoshi/ebiten/v2" -) +type functiontype struct { + k, b float64 +} -func f(x float64) float64 { return x*x + 5*x - 3 } -func df(x float64) float64 { return 2*x + 5 } +type coord struct { + x, y float64 +} func main() { - ebiten.SetWindowSize(640, 480) - ebiten.SetWindowTitle("Gradient descent") - - img := make(chan *image.RGBA, 1) - go func() { - p := Plot(-5, 0, 0.1, f) - x := 0.0 - img <- p(x) - for i := 0; i < 50; i++ { - time.Sleep(30 * time.Millisecond) - x -= df(x) * 0.1 - img <- p(x) + points := []coord{{0, 1}, {-0.5, 0}, {0.5, 2}, {1, 3}, {2.5, 6}} + learnRate := 0.01 + epochs := 1000 + + train(points, learnRate, epochs) +} + +func train(points []coord, learnRate float64, epochs int) { + n := float64(len(points)) + var line functiontype + + for epoch := 1; epoch <= epochs; epoch++ { + fxi := findFxi(points, line) + er := findEr(points, fxi) + + var sumk, sumb float64 + for i, j := range er { + sumk += j * points[i].x + sumb += j } - }() - if err := ebiten.RunGame(&App{Img: img}); err != nil { - log.Fatal(err) + line.k += (2 / n) * sumk * learnRate + line.b += (2 / n) * sumb * learnRate + + loss := average(er) + + if epoch%100 == 0 { + fmt.Printf("Epoch %d, Loss: %f\n", epoch, loss) + } } } + +func findFxi(points []coord, line functiontype) (fxi []float64) { + for _, j := range points { + fxi = append(fxi, line.k*j.x+line.b) + } + return fxi +} + +func findEr(points []coord, fxi []float64) (er []float64) { + for i, j := range points { + er = append(er, j.y-fxi[i]) + } + return er +} + +func sliceSum(s []float64) (sum float64) { + for _, j := range s { + sum += j + } + return sum +} + +func average(s []float64) float64 { + return sliceSum(s) / float64(len(s)) +}