Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions app.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
package main

import (
"image"

"github.com/hajimehoshi/ebiten/v2"
"gonum.org/v1/plot"
)

const (
sW = 1250
sH = 720
)

type App struct {
Img <-chan *image.RGBA
img *ebiten.Image
width, height int //screen width & height
plot *plot.Plot //global access to plot
}

func (app *App) Update() error { return nil }
func (a *App) Update() error {
return nil
}

func (app *App) Draw(screen *ebiten.Image) {
select {
case img := <-app.Img:
app.img = ebiten.NewImageFromImage(img)
default:
}
if app.img != nil {
screen.DrawImage(app.img, nil)
func (a *App) Draw(screen *ebiten.Image) {
if a.plot != nil { //to avoid crash at the start
screen.DrawImage(PlotToImage(a.plot), &ebiten.DrawImageOptions{}) //drawing plot
}
}

func (app *App) Layout(outsideWidth, outsideHeight int) (screenWidth, screenHeight int) {
return outsideWidth, outsideHeight
func (a *App) Layout(inWidth, inHeight int) (outWidth, outHeight int) {
return a.width, a.height
}

func NewApp(width, height int) *App {
return &App{width: width, height: height}
}
16 changes: 11 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
module github.com/prog-1/gradient-descent
module main.go

go 1.21.1

require (
github.com/hajimehoshi/ebiten/v2 v2.6.3
gonum.org/v1/plot v0.14.0
)
require github.com/hajimehoshi/ebiten/v2 v2.6.3

require (
gioui.org v0.2.0 // indirect
gioui.org/cpu v0.0.0-20220412190645-f1e9e8c3b1f7 // indirect
gioui.org/shader v1.0.6 // indirect
gioui.org/x v0.2.0 // indirect
git.sr.ht/~sbinet/gg v0.5.0 // indirect
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b // indirect
github.com/andybalholm/stroke v0.0.0-20221221101821-bd29b49d73f0 // indirect
github.com/campoy/embedmd v1.0.0 // indirect
github.com/ebitengine/purego v0.5.0 // indirect
github.com/go-fonts/liberation v0.3.1 // indirect
github.com/go-latex/latex v0.0.0-20230307184459-12ec69307ad9 // indirect
github.com/go-pdf/fpdf v0.8.0 // indirect
github.com/go-text/typesetting v0.0.0-20230905121921-abdbcca6e0eb // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/jezek/xgb v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/exp v0.0.0-20230801115018-d63ba01acd4b // indirect
golang.org/x/exp/shiny v0.0.0-20230817173708-d852ddb80c63 // indirect
golang.org/x/image v0.12.0 // indirect
golang.org/x/mobile v0.0.0-20230922142353-e2f452493d57 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect
gonum.org/v1/plot v0.14.0 // indirect
rsc.io/pdf v0.1.1 // indirect
)
21 changes: 13 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
git.sr.ht/~sbinet/cmpimg v0.1.0 h1:E0zPRk2muWuCqSKSVZIWsgtU9pjsw3eKHi8VmQeScxo=
git.sr.ht/~sbinet/cmpimg v0.1.0/go.mod h1:FU12psLbF4TfNXkKH2ZZQ29crIqoiqTZmeQ7dkp/pxE=
gioui.org v0.2.0 h1:RbzDn1h/pCVf/q44ImQSa/J3MIFpY3OWphzT/Tyei+w=
gioui.org v0.2.0/go.mod h1:1H72sKEk/fNFV+l0JNeM2Dt3co3Y4uaQcD+I+/GQ0e4=
gioui.org/cpu v0.0.0-20210808092351-bfe733dd3334/go.mod h1:A8M0Cn5o+vY5LTMlnRoK3O5kG+rH0kWfJjeKd9QpBmQ=
gioui.org/cpu v0.0.0-20220412190645-f1e9e8c3b1f7 h1:tNJdnP5CgM39PRc+KWmBRRYX/zJ+rd5XaYxY5d5veqA=
gioui.org/cpu v0.0.0-20220412190645-f1e9e8c3b1f7/go.mod h1:A8M0Cn5o+vY5LTMlnRoK3O5kG+rH0kWfJjeKd9QpBmQ=
gioui.org/shader v1.0.6 h1:cvZmU+eODFR2545X+/8XucgZdTtEjR3QWW6W65b0q5Y=
gioui.org/shader v1.0.6/go.mod h1:mWdiME581d/kV7/iEhLmUgUK5iZ09XR5XpduXzbePVM=
gioui.org/x v0.2.0 h1:/MbdjKH19F16auv19UiQxli2n6BYPw7eyh9XBOTgmEw=
gioui.org/x v0.2.0/go.mod h1:rCGN2nZ8ZHqrtseJoQxCMZpt2xrZUrdZ2WuMRLBJmYs=
git.sr.ht/~sbinet/gg v0.5.0 h1:6V43j30HM623V329xA9Ntq+WJrMjDxRjuAB1LFWF5m8=
git.sr.ht/~sbinet/gg v0.5.0/go.mod h1:G2C0eRESqlKhS7ErsNey6HHrqU1PwsnCQlekFi9Q2Oo=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm/4RlzPXRlREEwqTHAN3T56Bv2ITsFT3gY=
github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk=
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyRiyQj/Ud48djTMtMebDqepE95rw=
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM=
github.com/andybalholm/stroke v0.0.0-20221221101821-bd29b49d73f0 h1:uF5Q/hWnDU1XZeT6CsrRSxHLroUSEYYO3kgES+yd+So=
github.com/andybalholm/stroke v0.0.0-20221221101821-bd29b49d73f0/go.mod h1:ccdDYaY5+gO+cbnQdFxEXqfy0RkoV25H3jLXUDNM3wg=
github.com/campoy/embedmd v1.0.0 h1:V4kI2qTJJLf4J29RzI/MAt2c3Bl4dQSYPuflzwFH2hY=
github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8=
github.com/ebitengine/purego v0.5.0 h1:JrMGKfRIAM4/QVKaesIIT7m/UVjTj5GYhRSQYwfVdpo=
github.com/ebitengine/purego v0.5.0/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ=
github.com/go-fonts/dejavu v0.1.0 h1:JSajPXURYqpr+Cu8U9bt8K+XcACIHWqWrvWCKyeFmVQ=
github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g=
github.com/go-fonts/latin-modern v0.3.1 h1:/cT8A7uavYKvglYXvrdDw4oS5ZLkcOU22fa2HJ1/JVM=
github.com/go-fonts/latin-modern v0.3.1/go.mod h1:ysEQXnuT/sCDOAONxC7ImeEDVINbltClhasMAqEtRK0=
github.com/go-fonts/liberation v0.3.1 h1:9RPT2NhUpxQ7ukUvz3jeUckmN42T9D9TpjtQcqK/ceM=
github.com/go-fonts/liberation v0.3.1/go.mod h1:jdJ+cqF+F4SUL2V+qxBth8fvBpBDS7yloUL5Fi8GTGY=
github.com/go-latex/latex v0.0.0-20230307184459-12ec69307ad9 h1:NxXI5pTAtpEaU49bpLpQoDsu1zrteW/vxzTz8Cd2UAs=
github.com/go-latex/latex v0.0.0-20230307184459-12ec69307ad9/go.mod h1:gWuR/CrFDDeVRFQwHPvsv9soJVB/iqymhuZQuJ3a9OM=
github.com/go-pdf/fpdf v0.8.0 h1:IJKpdaagnWUeSkUFUjTcSzTppFxmv8ucGQyNPQWxYOQ=
github.com/go-pdf/fpdf v0.8.0/go.mod h1:gfqhcNwXrsd3XYKte9a7vM3smvU/jB4ZRDrmWSxpfdc=
github.com/go-text/typesetting v0.0.0-20230905121921-abdbcca6e0eb h1:4GpJirtA8yY24aqbU3uppiXGYiVpWfLIrqc2NNKKk9s=
github.com/go-text/typesetting v0.0.0-20230905121921-abdbcca6e0eb/go.mod h1:evDBbvNR/KaVFZ2ZlDSOWWXIUKq0wCOEtzLxRM8SG3k=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/hajimehoshi/ebiten/v2 v2.6.3 h1:xJ5klESxhflZbPUx3GdIPoITzgPgamsyv8aZCVguXGI=
Expand Down Expand Up @@ -87,8 +94,6 @@ golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
gonum.org/v1/plot v0.14.0 h1:+LBDVFYwFe4LHhdP8coW6296MBEY4nQ+Y4vuUpJopcE=
gonum.org/v1/plot v0.14.0/go.mod h1:MLdR9424SJed+5VqC6MsouEpig9pZX2VZ57H9ko2bXU=
honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las=
Expand Down
62 changes: 62 additions & 0 deletions linearRegression.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package main

import (
"fmt"
"time"
)

const (
lrB = 0.1
lrK = 0.0001
epochs = 1000
)

//#######################################################################

func loss(k, b float64, px, py []float64) float64 {
totalE := 0.0 // error
for i := range px {
x := px[i]
y := py[i]
totalE += (y - (k*x + b)) * 2
}
totalE /= float64(len(px))

return totalE
}

func inference(x, k, b float64) float64 {
return k*x + b
}

func gradientDescent(k, b float64, px, py []float64, epoch int) (float64, float64) {
dk, db := 0.0, 0.0 // gradients for coefficients
n := float64(len(px))
for i := range px {
x := px[i]
y := py[i]
dk -= (2 / n) * (y - (k*x + b)) * x
db -= (2 / n) * (y - (k*x + b))
}
k -= dk * lrK
b -= db * lrB
if epoch%100 == 0 {
fmt.Println("dk:", dk, "db:", db, "\n")
}
return k, b
}

func (a *App) linearRegression(px, py []float64) (k, b float64) {
for epoch := 1; epoch <= epochs; epoch++ {
if epoch%100 == 0 {
fmt.Println("Epoch:", epoch, "Loss:", loss(k, b, px, py))
}
k, b = gradientDescent(k, b, px, py, epoch)
a.updatePlot(k, b, px, py)
time.Sleep(time.Millisecond)
}

return k, b
}

//#######################################################################
46 changes: 30 additions & 16 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,47 @@
package main

import (
"image"
"log"
"time"
"math/rand"

"github.com/hajimehoshi/ebiten/v2"
)

func f(x float64) float64 { return x*x + 5*x - 3 }
func df(x float64) float64 { return 2*x + 5 }
const (
numberOfPoints = 20
pointMin, pointMax = 30, 70 //point distribution
lineMin, lineMax = 0, 100 //line lenght
)

func main() {
ebiten.SetWindowSize(640, 480)
ebiten.SetWindowTitle("Gradient descent")

img := make(chan *image.RGBA, 1)
//####################### Linear Regression #########################

//Generating random points
px := make([]float64, numberOfPoints)
py := make([]float64, numberOfPoints)
for i := 0; i < numberOfPoints; i++ {
px[i] = (rand.Float64()*(pointMax-pointMin) + pointMin)
py[i] = (rand.Float64()*(pointMax-pointMin) + pointMin)
}

//####################### Ebiten ####################################

//Window
ebiten.SetWindowSize(sW, sH)
ebiten.SetWindowTitle("Linear Regression")

//App instance
a := NewApp(sW, sH)

//
go func() {
p := Plot(-5, 0, 0.1, f)
x := 0.0
img <- p(x)
for i := 0; i < 50; i++ {
time.Sleep(30 * time.Millisecond)
x -= df(x) * 0.1
img <- p(x)
}
a.linearRegression(px, py)
}()

if err := ebiten.RunGame(&App{Img: img}); err != nil {
//Running game
if err := ebiten.RunGame(a); err != nil {
log.Fatal(err)
}

}
96 changes: 50 additions & 46 deletions plot.go
Original file line number Diff line number Diff line change
@@ -1,60 +1,64 @@
package main

import (
"fmt"
"image"
"image/color"
"log"

"github.com/hajimehoshi/ebiten/v2"
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/vg"
"gonum.org/v1/plot/vg/draw"
"gonum.org/v1/plot/vg/vgimg"
)

func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64) *image.RGBA {
var pts plotter.XYs
for x := xmin; x <= xmax; x += xstep {
pts = append(pts, plotter.XY{X: x, Y: f(x)})
}
fn, err := plotter.NewLine(pts)
if err != nil {
log.Fatalf("Failed to NewLine: %v", err)
// Converting plot to ebiten.Image
func PlotToImage(p *plot.Plot) *ebiten.Image {

img := image.NewRGBA(image.Rect(0, 0, sW, sH)) //creating image.RGBA to store the plot

c := vgimg.NewWith(vgimg.UseImage(img)) //creating plot drawer for the image

p.Draw(draw.New(c)) //drawing plot on the image

return ebiten.NewImageFromImage(c.Image()) //converting image.RGBA to ebiten.Image (doing in Draw)
///Black screen issue: was giving "img" instead of "c.Image()" in the function.
}

// recreating plot with given data
func (a *App) updatePlot(k, b float64, px, py []float64) {

p := plot.New() //initializing plot

//##################################################

//Line

linePoints := plotter.XYs{
{X: lineMin, Y: inference(lineMin, k, b)},
{X: lineMax, Y: inference(lineMax, k, b)},
}
fn.Color = color.RGBA{B: 255, A: 255}
return func(x float64) *image.RGBA {
pts := plotter.XYs{{X: x, Y: f(x)}}
xScatter, err := plotter.NewScatter(pts)
if err != nil {
log.Fatalf("Failed to NewScatter: %v", err)
}
xScatter.Color = color.RGBA{R: 255, A: 255}

labels, err := plotter.NewLabels(plotter.XYLabels{
XYs: pts,
Labels: []string{fmt.Sprintf("x = %.2f", x)},
})
labels.Offset = vg.Point{X: -10, Y: 15}
if err != nil {
log.Fatalf("Failed to NewLabels: %v", err)
}

p := plot.New()
p.Add(
plotter.NewGrid(),
fn,
xScatter,
labels,
)
p.Legend.Add("f(x)", fn)
p.Legend.Add("x", xScatter)
p.X.Label.Text = "X"
p.Y.Label.Text = "Y"

img := image.NewRGBA(image.Rect(0, 0, 640, 480))
c := vgimg.NewWith(vgimg.UseImage(img))
p.Draw(draw.New(c))
return c.Image().(*image.RGBA)

line, _ := plotter.NewLine(linePoints) //creating line

p.Add(line) //adding line to the plot°

//##################################################

//Points

var points plotter.XYs //initializing point plotter

for i := 0; i < len(px); i++ {
points = append(points, plotter.XY{X: px[i], Y: py[i]}) //Saving all points in plotter
}

scatter, _ := plotter.NewScatter(points) //creating new scatter from point data

p.Add(scatter) //adding points to plot

//##################################################

//App

a.plot = p //replacing old plot with new one

}