From 307ff2abfb83ff63aeba074d36d855d757942509 Mon Sep 17 00:00:00 2001 From: Richard Shashalevich Date: Wed, 10 Jan 2024 21:44:12 +0200 Subject: [PATCH 1/4] =?UTF-8?q?=E2=A0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.go | 36 ++++++++++++++++++++++++++++-------- plot.go | 14 ++++++++++---- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/main.go b/main.go index e39985b..269d517 100644 --- a/main.go +++ b/main.go @@ -8,26 +8,46 @@ import ( "github.com/hajimehoshi/ebiten/v2" ) -func f(x float64) float64 { return x*x + 5*x - 3 } +const ( + e = 1e-5 +) + +func f(x float64) float64 { return 0.01*x*x*x*x + x*x + 5*x - 3 } func df(x float64) float64 { return 2*x + 5 } +func Derivative(x float64, f func(x float64) float64) float64 { + return (f(x+e) - f(x)) / e +} + func main() { ebiten.SetWindowSize(640, 480) ebiten.SetWindowTitle("Gradient descent") img := make(chan *image.RGBA, 1) go func() { - p := Plot(-5, 0, 0.1, f) - x := 0.0 - img <- p(x) - for i := 0; i < 50; i++ { - time.Sleep(30 * time.Millisecond) - x -= df(x) * 0.1 - img <- p(x) + p := Plot(-10, 10, 0.1, f) + x := 10.0 + img <- p(x, false) + for { + time.Sleep(300 * time.Millisecond) + x -= Derivative(x, f) * 0.3 + img <- p(x, false) + if Abs(Derivative(x, f)) < e { + img <- p(x, true) + break + } } + }() if err := ebiten.RunGame(&App{Img: img}); err != nil { log.Fatal(err) } } + +func Abs(a float64) float64 { + if a < 0 { + return a * -1 + } + return a +} diff --git a/plot.go b/plot.go index ebad8a9..b6eae71 100644 --- a/plot.go +++ b/plot.go @@ -13,7 +13,7 @@ import ( "gonum.org/v1/plot/vg/vgimg" ) -func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64) *image.RGBA { +func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64, isMinimum bool) *image.RGBA { var pts plotter.XYs for x := xmin; x <= xmax; x += xstep { pts = append(pts, plotter.XY{X: x, Y: f(x)}) @@ -23,17 +23,23 @@ func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64) *i log.Fatalf("Failed to NewLine: %v", err) } fn.Color = color.RGBA{B: 255, A: 255} - return func(x float64) *image.RGBA { + return func(x float64, isMinimum bool) *image.RGBA { pts := plotter.XYs{{X: x, Y: f(x)}} xScatter, err := plotter.NewScatter(pts) if err != nil { log.Fatalf("Failed to NewScatter: %v", err) } - xScatter.Color = color.RGBA{R: 255, A: 255} + if isMinimum { + xScatter.Color = color.RGBA{G: 255, A: 255} + + } else { + xScatter.Color = color.RGBA{R: 255, A: 255} + + } labels, err := plotter.NewLabels(plotter.XYLabels{ XYs: pts, - Labels: []string{fmt.Sprintf("x = %.2f", x)}, + Labels: []string{fmt.Sprintf("x = %.5f", x)}, }) labels.Offset = vg.Point{X: -10, Y: 15} if err != nil { From 27e376044cd0573a6dc77e52fb3c1b42637ee382 Mon Sep 17 00:00:00 2001 From: Richard Shashalevich Date: Thu, 11 Jan 2024 09:57:07 +0200 Subject: [PATCH 2/4] some fixes --- main.go | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/main.go b/main.go index 269d517..282cd88 100644 --- a/main.go +++ b/main.go @@ -12,12 +12,13 @@ const ( e = 1e-5 ) -func f(x float64) float64 { return 0.01*x*x*x*x + x*x + 5*x - 3 } -func df(x float64) float64 { return 2*x + 5 } +func f(x float64) float64 { return x*x*x*x + x*x + 5*x - 3 } -func Derivative(x float64, f func(x float64) float64) float64 { - return (f(x+e) - f(x)) / e -} +func df(x float64) float64 { return 4*x*x*x + 2*x + 5 } + +// func Derivative(x float64, f func(x float64) float64) float64 { +// return (f(x+e) - f(x)) / e +// } func main() { ebiten.SetWindowSize(640, 480) @@ -26,16 +27,12 @@ func main() { img := make(chan *image.RGBA, 1) go func() { p := Plot(-10, 10, 0.1, f) - x := 10.0 + x := 5.0 img <- p(x, false) - for { - time.Sleep(300 * time.Millisecond) - x -= Derivative(x, f) * 0.3 + for i := 0; i < 5000; i++ { + time.Sleep(30 * time.Millisecond) + x -= df(x) * 0.001 img <- p(x, false) - if Abs(Derivative(x, f)) < e { - img <- p(x, true) - break - } } }() From 08e1b4552b4bd52bac7118f9bc06fe3837e662f7 Mon Sep 17 00:00:00 2001 From: Richard Shashalevich Date: Wed, 17 Jan 2024 21:48:05 +0200 Subject: [PATCH 3/4] hw --- Line.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ main.go | 17 +++++++++++------ point.go | 15 +++++++++++++++ 3 files changed, 70 insertions(+), 6 deletions(-) create mode 100644 Line.go create mode 100644 point.go diff --git a/Line.go b/Line.go new file mode 100644 index 0000000..2955b2e --- /dev/null +++ b/Line.go @@ -0,0 +1,44 @@ +package main + +import ( + "fmt" + "log" + "math" +) + +type line struct { + k float64 + b float64 +} + +func (l *line) y(x float64) float64 { + return l.k*x + l.b +} + +func NewLine() *line { + return &line{0, 0} +} + +func (l *line) Train(points []point, lr float64, epochs uint) error { + for i := uint(0); i < epochs; i++ { + y := make([]float64, len(points)) + for j := range points { + y[j] = l.y(points[j].x) + } + + var sum1, sum2 float64 + for j := range points { + sum1 += y[j] - points[j].y + sum2 += y[j] - points[j].y*points[j].x + } + + l.k -= lr * (2 / float64(len(points))) * sum2 + l.b -= lr * (2 / float64(len(points))) * sum1 + loss := sum1 / float64(len(points)) + if math.IsNaN(loss) || math.IsInf(loss, 0) { + return fmt.Errorf("training unsuccessful") + } + log.Printf("Epoch: %d/%d, Loss %f\n", i, epochs, loss) + } + return nil +} diff --git a/main.go b/main.go index 282cd88..bcb6f3c 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package main import ( "image" "log" - "time" "github.com/hajimehoshi/ebiten/v2" ) @@ -26,12 +25,18 @@ func main() { img := make(chan *image.RGBA, 1) go func() { - p := Plot(-10, 10, 0.1, f) - x := 5.0 + points := NewRandomPoints(10) + l := NewLine() + err := l.Train(points, 0.01, 5000) + if err != nil { + log.Fatal(err) + } + p := Plot(-30, 30, 0.1, l.y) + x := 30.0 img <- p(x, false) - for i := 0; i < 5000; i++ { - time.Sleep(30 * time.Millisecond) - x -= df(x) * 0.001 + for i := 0; i < 500000; i++ { + // time.Sleep( * time.Millisecond) + x -= df(x) * 1e-6 img <- p(x, false) } diff --git a/point.go b/point.go new file mode 100644 index 0000000..b934459 --- /dev/null +++ b/point.go @@ -0,0 +1,15 @@ +package main + +import "math/rand" + +type point struct { + x, y float64 +} + +func NewRandomPoints(num int) []point { + points := make([]point, num) + for i := 0; i < num; i++ { + points[i] = point{rand.Float64() * 300, rand.Float64() * 300} + } + return points +} From 25d6860d134004c755d26778001cc5d812f09159 Mon Sep 17 00:00:00 2001 From: Richard Shashalevich Date: Wed, 17 Jan 2024 21:48:28 +0200 Subject: [PATCH 4/4] hw --- main.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/main.go b/main.go index bcb6f3c..92d5f2b 100644 --- a/main.go +++ b/main.go @@ -20,18 +20,18 @@ func df(x float64) float64 { return 4*x*x*x + 2*x + 5 } // } func main() { + p := NewRandomPoints(10) + l := NewLine() + err := l.Train(p, 0.001, 5000) + if err != nil { + log.Fatal(err) + } ebiten.SetWindowSize(640, 480) ebiten.SetWindowTitle("Gradient descent") img := make(chan *image.RGBA, 1) go func() { - points := NewRandomPoints(10) - l := NewLine() - err := l.Train(points, 0.01, 5000) - if err != nil { - log.Fatal(err) - } - p := Plot(-30, 30, 0.1, l.y) + p := Plot(-30, 30, 0.1, f) x := 30.0 img <- p(x, false) for i := 0; i < 500000; i++ {