diff --git a/app.go b/app.go index 90cb206..894d3df 100644 --- a/app.go +++ b/app.go @@ -1,29 +1,34 @@ package main import ( - "image" - "github.com/hajimehoshi/ebiten/v2" + "gonum.org/v1/plot" +) + +const ( + sW = 1250 + sH = 720 ) type App struct { - Img <-chan *image.RGBA - img *ebiten.Image + width, height int //screen width & height + plot *plot.Plot //global access to plot } -func (app *App) Update() error { return nil } +func (a *App) Update() error { + return nil +} -func (app *App) Draw(screen *ebiten.Image) { - select { - case img := <-app.Img: - app.img = ebiten.NewImageFromImage(img) - default: - } - if app.img != nil { - screen.DrawImage(app.img, nil) +func (a *App) Draw(screen *ebiten.Image) { + if a.plot != nil { //to avoid crash at the start + screen.DrawImage(PlotToImage(a.plot), &ebiten.DrawImageOptions{}) //drawing plot } } -func (app *App) Layout(outsideWidth, outsideHeight int) (screenWidth, screenHeight int) { - return outsideWidth, outsideHeight +func (a *App) Layout(inWidth, inHeight int) (outWidth, outHeight int) { + return a.width, a.height +} + +func NewApp(width, height int) *App { + return &App{width: width, height: height} } diff --git a/go.mod b/go.mod index fd11da7..113ff28 100644 --- a/go.mod +++ b/go.mod @@ -1,27 +1,33 @@ -module github.com/prog-1/gradient-descent +module main.go go 1.21.1 -require ( - github.com/hajimehoshi/ebiten/v2 v2.6.3 - gonum.org/v1/plot v0.14.0 -) +require github.com/hajimehoshi/ebiten/v2 v2.6.3 require ( + gioui.org v0.2.0 // indirect + gioui.org/cpu v0.0.0-20220412190645-f1e9e8c3b1f7 // indirect + gioui.org/shader v1.0.6 // indirect + gioui.org/x v0.2.0 // indirect git.sr.ht/~sbinet/gg v0.5.0 // indirect github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b // indirect + github.com/andybalholm/stroke v0.0.0-20221221101821-bd29b49d73f0 // indirect github.com/campoy/embedmd v1.0.0 // indirect github.com/ebitengine/purego v0.5.0 // indirect github.com/go-fonts/liberation v0.3.1 // indirect github.com/go-latex/latex v0.0.0-20230307184459-12ec69307ad9 // indirect github.com/go-pdf/fpdf v0.8.0 // indirect + github.com/go-text/typesetting v0.0.0-20230905121921-abdbcca6e0eb // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/jezek/xgb v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/exp v0.0.0-20230801115018-d63ba01acd4b // indirect golang.org/x/exp/shiny v0.0.0-20230817173708-d852ddb80c63 // indirect golang.org/x/image v0.12.0 // indirect golang.org/x/mobile v0.0.0-20230922142353-e2f452493d57 // indirect golang.org/x/sync v0.3.0 // indirect golang.org/x/sys v0.12.0 // indirect golang.org/x/text v0.13.0 // indirect + gonum.org/v1/plot v0.14.0 // indirect + rsc.io/pdf v0.1.1 // indirect ) diff --git a/go.sum b/go.sum index 3acb0f1..064640b 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,12 @@ -git.sr.ht/~sbinet/cmpimg v0.1.0 h1:E0zPRk2muWuCqSKSVZIWsgtU9pjsw3eKHi8VmQeScxo= -git.sr.ht/~sbinet/cmpimg v0.1.0/go.mod h1:FU12psLbF4TfNXkKH2ZZQ29crIqoiqTZmeQ7dkp/pxE= +gioui.org v0.2.0 h1:RbzDn1h/pCVf/q44ImQSa/J3MIFpY3OWphzT/Tyei+w= +gioui.org v0.2.0/go.mod h1:1H72sKEk/fNFV+l0JNeM2Dt3co3Y4uaQcD+I+/GQ0e4= +gioui.org/cpu v0.0.0-20210808092351-bfe733dd3334/go.mod h1:A8M0Cn5o+vY5LTMlnRoK3O5kG+rH0kWfJjeKd9QpBmQ= +gioui.org/cpu v0.0.0-20220412190645-f1e9e8c3b1f7 h1:tNJdnP5CgM39PRc+KWmBRRYX/zJ+rd5XaYxY5d5veqA= +gioui.org/cpu v0.0.0-20220412190645-f1e9e8c3b1f7/go.mod h1:A8M0Cn5o+vY5LTMlnRoK3O5kG+rH0kWfJjeKd9QpBmQ= +gioui.org/shader v1.0.6 h1:cvZmU+eODFR2545X+/8XucgZdTtEjR3QWW6W65b0q5Y= +gioui.org/shader v1.0.6/go.mod h1:mWdiME581d/kV7/iEhLmUgUK5iZ09XR5XpduXzbePVM= +gioui.org/x v0.2.0 h1:/MbdjKH19F16auv19UiQxli2n6BYPw7eyh9XBOTgmEw= +gioui.org/x v0.2.0/go.mod h1:rCGN2nZ8ZHqrtseJoQxCMZpt2xrZUrdZ2WuMRLBJmYs= git.sr.ht/~sbinet/gg v0.5.0 h1:6V43j30HM623V329xA9Ntq+WJrMjDxRjuAB1LFWF5m8= git.sr.ht/~sbinet/gg v0.5.0/go.mod h1:G2C0eRESqlKhS7ErsNey6HHrqU1PwsnCQlekFi9Q2Oo= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= @@ -7,20 +14,20 @@ github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk= github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyRiyQj/Ud48djTMtMebDqepE95rw= github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM= +github.com/andybalholm/stroke v0.0.0-20221221101821-bd29b49d73f0 h1:uF5Q/hWnDU1XZeT6CsrRSxHLroUSEYYO3kgES+yd+So= +github.com/andybalholm/stroke v0.0.0-20221221101821-bd29b49d73f0/go.mod h1:ccdDYaY5+gO+cbnQdFxEXqfy0RkoV25H3jLXUDNM3wg= github.com/campoy/embedmd v1.0.0 h1:V4kI2qTJJLf4J29RzI/MAt2c3Bl4dQSYPuflzwFH2hY= github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8= github.com/ebitengine/purego v0.5.0 h1:JrMGKfRIAM4/QVKaesIIT7m/UVjTj5GYhRSQYwfVdpo= github.com/ebitengine/purego v0.5.0/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ= -github.com/go-fonts/dejavu v0.1.0 h1:JSajPXURYqpr+Cu8U9bt8K+XcACIHWqWrvWCKyeFmVQ= -github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g= -github.com/go-fonts/latin-modern v0.3.1 h1:/cT8A7uavYKvglYXvrdDw4oS5ZLkcOU22fa2HJ1/JVM= -github.com/go-fonts/latin-modern v0.3.1/go.mod h1:ysEQXnuT/sCDOAONxC7ImeEDVINbltClhasMAqEtRK0= github.com/go-fonts/liberation v0.3.1 h1:9RPT2NhUpxQ7ukUvz3jeUckmN42T9D9TpjtQcqK/ceM= github.com/go-fonts/liberation v0.3.1/go.mod h1:jdJ+cqF+F4SUL2V+qxBth8fvBpBDS7yloUL5Fi8GTGY= github.com/go-latex/latex v0.0.0-20230307184459-12ec69307ad9 h1:NxXI5pTAtpEaU49bpLpQoDsu1zrteW/vxzTz8Cd2UAs= github.com/go-latex/latex v0.0.0-20230307184459-12ec69307ad9/go.mod h1:gWuR/CrFDDeVRFQwHPvsv9soJVB/iqymhuZQuJ3a9OM= github.com/go-pdf/fpdf v0.8.0 h1:IJKpdaagnWUeSkUFUjTcSzTppFxmv8ucGQyNPQWxYOQ= github.com/go-pdf/fpdf v0.8.0/go.mod h1:gfqhcNwXrsd3XYKte9a7vM3smvU/jB4ZRDrmWSxpfdc= +github.com/go-text/typesetting v0.0.0-20230905121921-abdbcca6e0eb h1:4GpJirtA8yY24aqbU3uppiXGYiVpWfLIrqc2NNKKk9s= +github.com/go-text/typesetting v0.0.0-20230905121921-abdbcca6e0eb/go.mod h1:evDBbvNR/KaVFZ2ZlDSOWWXIUKq0wCOEtzLxRM8SG3k= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/hajimehoshi/ebiten/v2 v2.6.3 h1:xJ5klESxhflZbPUx3GdIPoITzgPgamsyv8aZCVguXGI= @@ -87,8 +94,6 @@ golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= -gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU= gonum.org/v1/plot v0.14.0 h1:+LBDVFYwFe4LHhdP8coW6296MBEY4nQ+Y4vuUpJopcE= gonum.org/v1/plot v0.14.0/go.mod h1:MLdR9424SJed+5VqC6MsouEpig9pZX2VZ57H9ko2bXU= honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= diff --git a/linearRegression.go b/linearRegression.go new file mode 100644 index 0000000..815973e --- /dev/null +++ b/linearRegression.go @@ -0,0 +1,62 @@ +package main + +import ( + "fmt" + "time" +) + +const ( + lrB = 0.1 + lrK = 0.0001 + epochs = 1000 +) + +//####################################################################### + +func loss(k, b float64, px, py []float64) float64 { + totalE := 0.0 // error + for i := range px { + x := px[i] + y := py[i] + totalE += (y - (k*x + b)) * 2 + } + totalE /= float64(len(px)) + + return totalE +} + +func inference(x, k, b float64) float64 { + return k*x + b +} + +func gradientDescent(k, b float64, px, py []float64, epoch int) (float64, float64) { + dk, db := 0.0, 0.0 // gradients for coefficients + n := float64(len(px)) + for i := range px { + x := px[i] + y := py[i] + dk -= (2 / n) * (y - (k*x + b)) * x + db -= (2 / n) * (y - (k*x + b)) + } + k -= dk * lrK + b -= db * lrB + if epoch%100 == 0 { + fmt.Println("dk:", dk, "db:", db, "\n") + } + return k, b +} + +func (a *App) linearRegression(px, py []float64) (k, b float64) { + for epoch := 1; epoch <= epochs; epoch++ { + if epoch%100 == 0 { + fmt.Println("Epoch:", epoch, "Loss:", loss(k, b, px, py)) + } + k, b = gradientDescent(k, b, px, py, epoch) + a.updatePlot(k, b, px, py) + time.Sleep(time.Millisecond) + } + + return k, b +} + +//####################################################################### diff --git a/main.go b/main.go index e39985b..6952152 100644 --- a/main.go +++ b/main.go @@ -1,33 +1,47 @@ package main import ( - "image" "log" - "time" + "math/rand" "github.com/hajimehoshi/ebiten/v2" ) -func f(x float64) float64 { return x*x + 5*x - 3 } -func df(x float64) float64 { return 2*x + 5 } +const ( + numberOfPoints = 20 + pointMin, pointMax = 30, 70 //point distribution + lineMin, lineMax = 0, 100 //line lenght +) func main() { - ebiten.SetWindowSize(640, 480) - ebiten.SetWindowTitle("Gradient descent") - img := make(chan *image.RGBA, 1) + //####################### Linear Regression ######################### + + //Generating random points + px := make([]float64, numberOfPoints) + py := make([]float64, numberOfPoints) + for i := 0; i < numberOfPoints; i++ { + px[i] = (rand.Float64()*(pointMax-pointMin) + pointMin) + py[i] = (rand.Float64()*(pointMax-pointMin) + pointMin) + } + + //####################### Ebiten #################################### + + //Window + ebiten.SetWindowSize(sW, sH) + ebiten.SetWindowTitle("Linear Regression") + + //App instance + a := NewApp(sW, sH) + + // go func() { - p := Plot(-5, 0, 0.1, f) - x := 0.0 - img <- p(x) - for i := 0; i < 50; i++ { - time.Sleep(30 * time.Millisecond) - x -= df(x) * 0.1 - img <- p(x) - } + a.linearRegression(px, py) }() - if err := ebiten.RunGame(&App{Img: img}); err != nil { + //Running game + if err := ebiten.RunGame(a); err != nil { log.Fatal(err) } + } diff --git a/plot.go b/plot.go index ebad8a9..1f99b98 100644 --- a/plot.go +++ b/plot.go @@ -1,60 +1,64 @@ package main import ( - "fmt" "image" - "image/color" - "log" + "github.com/hajimehoshi/ebiten/v2" "gonum.org/v1/plot" "gonum.org/v1/plot/plotter" - "gonum.org/v1/plot/vg" "gonum.org/v1/plot/vg/draw" "gonum.org/v1/plot/vg/vgimg" ) -func Plot(xmin, xmax, xstep float64, f func(float64) float64) func(x float64) *image.RGBA { - var pts plotter.XYs - for x := xmin; x <= xmax; x += xstep { - pts = append(pts, plotter.XY{X: x, Y: f(x)}) - } - fn, err := plotter.NewLine(pts) - if err != nil { - log.Fatalf("Failed to NewLine: %v", err) +// Converting plot to ebiten.Image +func PlotToImage(p *plot.Plot) *ebiten.Image { + + img := image.NewRGBA(image.Rect(0, 0, sW, sH)) //creating image.RGBA to store the plot + + c := vgimg.NewWith(vgimg.UseImage(img)) //creating plot drawer for the image + + p.Draw(draw.New(c)) //drawing plot on the image + + return ebiten.NewImageFromImage(c.Image()) //converting image.RGBA to ebiten.Image (doing in Draw) + ///Black screen issue: was giving "img" instead of "c.Image()" in the function. +} + +// recreating plot with given data +func (a *App) updatePlot(k, b float64, px, py []float64) { + + p := plot.New() //initializing plot + + //################################################## + + //Line + + linePoints := plotter.XYs{ + {X: lineMin, Y: inference(lineMin, k, b)}, + {X: lineMax, Y: inference(lineMax, k, b)}, } - fn.Color = color.RGBA{B: 255, A: 255} - return func(x float64) *image.RGBA { - pts := plotter.XYs{{X: x, Y: f(x)}} - xScatter, err := plotter.NewScatter(pts) - if err != nil { - log.Fatalf("Failed to NewScatter: %v", err) - } - xScatter.Color = color.RGBA{R: 255, A: 255} - - labels, err := plotter.NewLabels(plotter.XYLabels{ - XYs: pts, - Labels: []string{fmt.Sprintf("x = %.2f", x)}, - }) - labels.Offset = vg.Point{X: -10, Y: 15} - if err != nil { - log.Fatalf("Failed to NewLabels: %v", err) - } - - p := plot.New() - p.Add( - plotter.NewGrid(), - fn, - xScatter, - labels, - ) - p.Legend.Add("f(x)", fn) - p.Legend.Add("x", xScatter) - p.X.Label.Text = "X" - p.Y.Label.Text = "Y" - - img := image.NewRGBA(image.Rect(0, 0, 640, 480)) - c := vgimg.NewWith(vgimg.UseImage(img)) - p.Draw(draw.New(c)) - return c.Image().(*image.RGBA) + + line, _ := plotter.NewLine(linePoints) //creating line + + p.Add(line) //adding line to the plot° + + //################################################## + + //Points + + var points plotter.XYs //initializing point plotter + + for i := 0; i < len(px); i++ { + points = append(points, plotter.XY{X: px[i], Y: py[i]}) //Saving all points in plotter } + + scatter, _ := plotter.NewScatter(points) //creating new scatter from point data + + p.Add(scatter) //adding points to plot + + //################################################## + + //App + + a.plot = p //replacing old plot with new one + }