Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions Line.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package main

import (
"fmt"
"log"
"math"
)

type line struct {
k float64
b float64
}

func (l *line) y(x float64) float64 {
return l.k*x + l.b
}

func NewLine() *line {
return &line{0, 0}
}

func (l *line) Train(points []point, lr float64, epochs uint) error {
for i := uint(0); i < epochs; i++ {
y := make([]float64, len(points))
for j := range points {
y[j] = l.y(points[j].x)
}

var sum1, sum2 float64
for j := range points {
sum1 += y[j] - points[j].y
sum2 += y[j] - points[j].y*points[j].x
}

l.k -= lr * (2 / float64(len(points))) * sum2
l.b -= lr * (2 / float64(len(points))) * sum1
loss := sum1 / float64(len(points))
if math.IsNaN(loss) || math.IsInf(loss, 0) {
return fmt.Errorf("training unsuccessful")
}
log.Printf("Epoch: %d/%d, Loss %f\n", i, epochs, loss)
}
return nil
}
42 changes: 32 additions & 10 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,53 @@ package main
import (
"image"
"log"
"time"

"github.com/hajimehoshi/ebiten/v2"
)

func f(x float64) float64 { return x*x + 5*x - 3 }
func df(x float64) float64 { return 2*x + 5 }
const (
e = 1e-5
)

func f(x float64) float64 { return x*x*x*x + x*x + 5*x - 3 }

func df(x float64) float64 { return 4*x*x*x + 2*x + 5 }

// func Derivative(x float64, f func(x float64) float64) float64 {
// return (f(x+e) - f(x)) / e
// }

func main() {
p := NewRandomPoints(10)
l := NewLine()
err := l.Train(p, 0.001, 5000)
if err != nil {
log.Fatal(err)
}
ebiten.SetWindowSize(640, 480)
ebiten.SetWindowTitle("Gradient descent")

img := make(chan *image.RGBA, 1)
go func() {
p := Plot(-5, 0, 0.1, f)
x := 0.0
img <- p(x)
for i := 0; i < 50; i++ {
time.Sleep(30 * time.Millisecond)
x -= df(x) * 0.1
img <- p(x)
p := Plot(-30, 30, 0.1, f)
x := 30.0
img <- p(x, false)
for i := 0; i < 500000; i++ {
// time.Sleep( * time.Millisecond)
x -= df(x) * 1e-6
img <- p(x, false)
}

}()

if err := ebiten.RunGame(&App{Img: img}); err != nil {
log.Fatal(err)
}
}

func Abs(a float64) float64 {
if a < 0 {
return a * -1
}
return a
}
14 changes: 10 additions & 4 deletions plot.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"gonum.org/v1/plot/vg/vgimg"
)

func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64) *image.RGBA {
func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64, isMinimum bool) *image.RGBA {
var pts plotter.XYs
for x := xmin; x <= xmax; x += xstep {
pts = append(pts, plotter.XY{X: x, Y: f(x)})
Expand All @@ -23,17 +23,23 @@ func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64) *i
log.Fatalf("Failed to NewLine: %v", err)
}
fn.Color = color.RGBA{B: 255, A: 255}
return func(x float64) *image.RGBA {
return func(x float64, isMinimum bool) *image.RGBA {
pts := plotter.XYs{{X: x, Y: f(x)}}
xScatter, err := plotter.NewScatter(pts)
if err != nil {
log.Fatalf("Failed to NewScatter: %v", err)
}
xScatter.Color = color.RGBA{R: 255, A: 255}
if isMinimum {
xScatter.Color = color.RGBA{G: 255, A: 255}

} else {
xScatter.Color = color.RGBA{R: 255, A: 255}

}

labels, err := plotter.NewLabels(plotter.XYLabels{
XYs: pts,
Labels: []string{fmt.Sprintf("x = %.2f", x)},
Labels: []string{fmt.Sprintf("x = %.5f", x)},
})
labels.Offset = vg.Point{X: -10, Y: 15}
if err != nil {
Expand Down
15 changes: 15 additions & 0 deletions point.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package main

import "math/rand"

type point struct {
x, y float64
}

func NewRandomPoints(num int) []point {
points := make([]point, num)
for i := 0; i < num; i++ {
points[i] = point{rand.Float64() * 300, rand.Float64() * 300}
}
return points
}