Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/prog-1/gradient-descent
module gradient-descent

go 1.21.1
go 1.21.6

require (
github.com/hajimehoshi/ebiten/v2 v2.6.3
Expand Down
141 changes: 126 additions & 15 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,142 @@
package main

import (
"fmt"
"image"
"log"
"time"
"math/rand"

"github.com/hajimehoshi/ebiten/v2"
"gonum.org/v1/plot/plotter"
)

func f(x float64) float64 { return x*x + 5*x - 3 }
func df(x float64) float64 { return 2*x + 5 }
const (
screenWidth, screenHeight = 720, 480
randMin, randMax = -10, 10
epochs, lr = 1000000, 0.000005
plotMinX, plotMaxX, plotMinY, plotMaxY = -10, 10, -50, 100 // Min and Max data values along both axis
pointMinYOffset, pointMaxYOffset, pointCount = -20, 20, 10
)

// Function points are spawed along
func f(x float64) float64 {
// return 0.5*x + 2
return 10*x - 5
}

// Inference for 1 argument(x)
func i(x, w, b float64) float64 { return w*x + b }

// Runs model on all the input data
func inference(x []float64, w, b float64) (out []float64) {
for _, v := range x {
out = append(out, i(v, w, b))
}
return
}

func loss(labels, y []float64) float64 {
var errSum float64
for i := range labels {
errSum += (y[i] - labels[i]) * (y[i] - labels[i])
}
return errSum / float64(len(labels)) // For the sake of making numbers smaller -> better percievable
}

func gradient(labels, y, x []float64) (dw, db float64) {
// dw, db - Parial derivatives, w - weight, b - bias
for i := 0; i < len(labels); i++ {
dif := y[i] - labels[i]
dw += dif * x[i]
db += dif
}
n := float64(len(labels))
dw *= 2 / n
db *= 2 / n

return
}

func train(epochs int, inputs, labels []float64) (w, b float64) {
randFloat64 := func() float64 {
return randMin + rand.Float64()*(randMax-randMin)
}
w, b = randFloat64(), randFloat64()
// w, b = 1, 0
var dw, db float64
for i := 0; i < epochs; i++ {
dw, db = gradient(labels, inference(inputs, w, b), inputs)
w -= dw * lr
b -= db * lr
fmt.Println(w, b)
}
return
}

// Returns random points along f with random Yoffset
func randPoints(f func(float64) float64, minYoffset float64, maxYoffset float64, pointCount uint) (xs, labels []float64) {
// 1. Getting random argument value X
// 2. Getting function value(Y)
// 3. Applying offset to Y
for i := uint(0); i < pointCount; i++ {
x := plotMinX + rand.Float64()*(plotMaxX-plotMinX) // Random argument within visible range
yOffset := minYoffset + rand.Float64()*(maxYoffset-minYoffset)
xs = append(xs, x)
labels = append(labels, f(x)+yOffset)
}
return
}

// func main() {
// xs, labels := randPoints(f, pointMinYOffset, pointMaxYOffset, pointCount)
// w, b := train(epochs, labels, xs)
//
// img := make(chan *image.RGBA, 1)
// go func() {
// p := Plot(plotMinX, plotMaxX, )
// }
// if err := ebiten.RunGame(&App{Img: img}); err != nil {
// log.Fatal(err)
// }
// }

// func main() {
// var loss plotter.XYs
//
// for i := 0; i < epochs; i++ {
// y := inference(inputs, w, b)
// loss = append(loss, plotter.XY{
// X: float64(i),
// Y: msl(labels, y),
// })
// lossLines, _ := plotter.NewLine(loss)
// if plotLoss {
// select {
// case img <- Plot(lossLines):
// default:
// }
// } else {
// const extra = (inputPointsMaxX - inputPointsMinX) / 10
// xs := []float64{inputPointsMinX - extra, inputPointsMaxX + extra}
// ys := inference(xs, w, b)
// resLine, _ := plotter.NewLine(plotter.XYs{{X: xs[0], Y: ys[0]}, {X: xs[1], Y: ys[1]}})
// img <- Plot(inputsScatter, resLine)
// }

func main() {
ebiten.SetWindowSize(640, 480)
ebiten.SetWindowTitle("Gradient descent")
inputs, labels := randPoints(f, pointMinYOffset, pointMaxYOffset, pointCount)
var points plotter.XYs
for i := 0; i < pointCount; i++ {
points = append(points, plotter.XY{X: inputs[i], Y: labels[i]})
}

img := make(chan *image.RGBA, 1)
go func() {
p := Plot(-5, 0, 0.1, f)
x := 0.0
img <- p(x)
for i := 0; i < 50; i++ {
time.Sleep(30 * time.Millisecond)
x -= df(x) * 0.1
img <- p(x)
}
}()
pointsScatter, _ := plotter.NewScatter(points)
fp := plotter.NewFunction(f) // f plot
w, b := train(epochs, inputs, labels)
fmt.Println(w, b)
ap := plotter.NewFunction(func(x float64) float64 { return w*x + b }) // approximating function plot
img <- Plot(pointsScatter, fp, ap)

if err := ebiten.RunGame(&App{Img: img}); err != nil {
log.Fatal(err)
Expand Down
105 changes: 59 additions & 46 deletions plot.go
Original file line number Diff line number Diff line change
@@ -1,60 +1,73 @@
package main

import (
"fmt"
"image"
"image/color"
"log"

"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/vg"
"gonum.org/v1/plot/vg/draw"
"gonum.org/v1/plot/vg/vgimg"
)

func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64) *image.RGBA {
var pts plotter.XYs
for x := xmin; x <= xmax; x += xstep {
pts = append(pts, plotter.XY{X: x, Y: f(x)})
}
fn, err := plotter.NewLine(pts)
if err != nil {
log.Fatalf("Failed to NewLine: %v", err)
}
fn.Color = color.RGBA{B: 255, A: 255}
return func(x float64) *image.RGBA {
pts := plotter.XYs{{X: x, Y: f(x)}}
xScatter, err := plotter.NewScatter(pts)
if err != nil {
log.Fatalf("Failed to NewScatter: %v", err)
}
xScatter.Color = color.RGBA{R: 255, A: 255}
//// Old one
// func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64) *image.RGBA {
// var pts plotter.XYs
// for x := xmin; x <= xmax; x += xstep {
// pts = append(pts, plotter.XY{X: x, Y: f(x)})
// }
// fn, err := plotter.NewLine(pts)
// if err != nil {
// log.Fatalf("Failed to NewLine: %v", err)
// }
// fn.Color = color.RGBA{B: 255, A: 255}
// return func(x float64) *image.RGBA {
// pts := plotter.XYs{{X: x, Y: f(x)}}
// xScatter, err := plotter.NewScatter(pts)
// if err != nil {
// log.Fatalf("Failed to NewScatter: %v", err)
// }
// xScatter.Color = color.RGBA{R: 255, A: 255}
//
// labels, err := plotter.NewLabels(plotter.XYLabels{
// XYs: pts,
// Labels: []string{fmt.Sprintf("x = %.2f", x)},
// })
// labels.Offset = vg.Point{X: -10, Y: 15}
// if err != nil {
// log.Fatalf("Failed to NewLabels: %v", err)
// }
//
// p := plot.New()
// p.Add(
// plotter.NewGrid(),
// fn,
// xScatter,
// labels,
// )
// p.Legend.Add("f(x)", fn)
// p.Legend.Add("x", xScatter)
// p.X.Label.Text = "X"
// p.Y.Label.Text = "Y"
//
// img := image.NewRGBA(image.Rect(0, 0, screenWidth, screenHeight))
// c := vgimg.NewWith(vgimg.UseImage(img))
// p.Draw(draw.New(c))
// return c.Image().(*image.RGBA)
// }
// }

labels, err := plotter.NewLabels(plotter.XYLabels{
XYs: pts,
Labels: []string{fmt.Sprintf("x = %.2f", x)},
})
labels.Offset = vg.Point{X: -10, Y: 15}
if err != nil {
log.Fatalf("Failed to NewLabels: %v", err)
}
func Plot(ps ...plot.Plotter) *image.RGBA {
p := plot.New()
p.X.Min = plotMinX
p.X.Max = plotMaxX
p.Y.Min = plotMinY
p.Y.Max = plotMaxY

p := plot.New()
p.Add(
plotter.NewGrid(),
fn,
xScatter,
labels,
)
p.Legend.Add("f(x)", fn)
p.Legend.Add("x", xScatter)
p.X.Label.Text = "X"
p.Y.Label.Text = "Y"

img := image.NewRGBA(image.Rect(0, 0, 640, 480))
c := vgimg.NewWith(vgimg.UseImage(img))
p.Draw(draw.New(c))
return c.Image().(*image.RGBA)
}
p.Add(append([]plot.Plotter{
plotter.NewGrid(),
}, ps...)...)
img := image.NewRGBA(image.Rect(0, 0, screenWidth, screenHeight))
c := vgimg.NewWith(vgimg.UseImage(img))
p.Draw(draw.New(c))
return c.Image().(*image.RGBA)
}