Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz committed Nov 5, 2019
2 parents a8aca64 + 3abae4b commit 66bb038
Show file tree
Hide file tree
Showing 35 changed files with 517 additions and 354 deletions.
19 changes: 14 additions & 5 deletions base/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,24 @@ import "log"
// RuntimeOptions defines options used for runtime.
type RuntimeOptions struct {
Verbose bool // Verbose switch
NJobs int // Number of jobs
FitJobs int // Number of jobs for model fitting
CVJobs int // Number of jobs for cross validation
}

// GetJobs returns the number of concurrent jobs.
func (options *RuntimeOptions) GetJobs() int {
if options == nil || options.NJobs < 1 {
// GetFitJobs returns the number of concurrent jobs for model fitting.
func (options *RuntimeOptions) GetFitJobs() int {
if options == nil || options.FitJobs < 1 {
return 1
}
return options.NJobs
return options.FitJobs
}

// GetCVJobs returns the number of concurrent jobs for cross validation.
func (options *RuntimeOptions) GetCVJobs() int {
if options == nil || options.CVJobs < 1 {
return 1
}
return options.CVJobs
}

// GetVerbose returns the indicator of verbose.
Expand Down
8 changes: 5 additions & 3 deletions base/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ func TestRuntimeOptions(t *testing.T) {
// Check default
var opt1 *RuntimeOptions
assert.Equal(t, true, opt1.GetVerbose())
assert.Equal(t, 1, opt1.GetJobs())
assert.Equal(t, 1, opt1.GetFitJobs())
assert.Equal(t, 1, opt1.GetCVJobs())
// Check options
opt2 := &RuntimeOptions{false, 10}
opt2 := &RuntimeOptions{false, 10, 5}
assert.Equal(t, false, opt2.GetVerbose())
assert.Equal(t, 10, opt2.GetJobs())
assert.Equal(t, 10, opt2.GetFitJobs())
assert.Equal(t, 5, opt2.GetCVJobs())
}
16 changes: 16 additions & 0 deletions base/parallel.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package base

import (
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/stat"
"sync"
)
Expand Down Expand Up @@ -37,6 +38,21 @@ func ParallelFor(begin, end int, worker func(i int)) {
wg.Wait()
}

// ParallelForSum runs for loop in parallel.
func ParallelForSum(begin, end int, worker func(i int) float64) float64 {
retValues := make([]float64, end-begin)
var wg sync.WaitGroup
wg.Add(end - begin)
for j := begin; j < end; j++ {
go func(i int) {
retValues[i] = worker(i)
wg.Done()
}(j)
}
wg.Wait()
return floats.Sum(retValues)
}

// ParallelMean schedules and runs tasks in parallel, then returns the mean of returned values.
// nJob is the number of executors. worker is the executed function which passed a range of task
// IDs (begin, end) and returns a double value.
Expand Down
4 changes: 2 additions & 2 deletions base/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ const (

// Predefined values for hyper-parameter Optimizer.
const (
SGD string = "sgd" // Fit model (MF) with stochastic gradient descent.
BPR string = "bpr" // Fit model (MF) with bayesian personal ranking.
SGDOptimizer string = "sgd" // Fit model (FM) with stochastic gradient descent.
BPROptimizer string = "bpr" // Fit model (FM) with bayesian personal ranking.
)

// Predefined values for hyper-parameter Similarity.
Expand Down
82 changes: 52 additions & 30 deletions cmd/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,29 +56,48 @@ func serve(config engine.ServerConfig) {

// Status contains information about engine.
type Status struct {
FeedbackCount int // number of feedback
ItemCount int // number of items
CommitCount int // number of committed feedback
FeedbackCount int // number of feedback
ItemCount int // number of items
UserCount int // number of users
CommitCount int // number of committed feedback
CommitTime string // time for commit
}

func getStatus(request *restful.Request, response *restful.Response) {
func status() (Status, error) {
status := Status{}
var err error
// Get feedback count
if status.FeedbackCount, err = db.CountFeedback(); err != nil {
internalServerError(response, err)
return status, err
}
// Get item count
if status.ItemCount, err = db.CountItems(); err != nil {
internalServerError(response, err)
return status, err
}
// Get user count
if status.UserCount, err = db.CountUsers(); err != nil {
return status, err
}
// Get commit count
var commit string
if commit, err = db.GetMeta("commit"); err != nil {
internalServerError(response, err)
return status, err
}
if status.CommitCount, err = strconv.Atoi(commit); len(commit) > 0 && err != nil {
return status, err
}
// Get commit time
if status.CommitTime, err = db.GetMeta("commit_time"); err != nil {
return status, err
}
return status, nil
}

func getStatus(request *restful.Request, response *restful.Response) {
status, err := status()
if err != nil {
internalServerError(response, err)
return
}
json(response, status)
}
Expand Down Expand Up @@ -217,6 +236,8 @@ func getRecommends(request *restful.Request, response *restful.Response) {
type Change struct {
ItemsBefore int // number of items before change
ItemsAfter int // number of items after change
UsersBefore int // number of users before change
UsersAfter int // number of user after change
FeedbackBefore int // number of feedback before change
FeedbackAfter int // number of feedback after change
}
Expand All @@ -231,29 +252,32 @@ func putItems(request *restful.Request, response *restful.Response) {
}
var err error
change := Change{}
change.FeedbackBefore, err = db.CountFeedback()
if err != nil {
internalServerError(response, err)
return
}
change.FeedbackAfter = change.FeedbackBefore
change.ItemsBefore, err = db.CountItems()
// Get status before change
stat, err := status()
if err != nil {
internalServerError(response, err)
return
}
change.FeedbackBefore = stat.FeedbackCount
change.ItemsBefore = stat.ItemCount
change.UsersBefore = stat.UserCount
// Insert items
for _, itemId := range *items {
err = db.InsertItem(itemId)
if err != nil {
internalServerError(response, err)
return
}
}
change.ItemsAfter, err = db.CountItems()
// Get status after change
stat, err = status()
if err != nil {
internalServerError(response, err)
return
}
change.FeedbackAfter = stat.FeedbackCount
change.ItemsAfter = stat.ItemCount
change.UsersAfter = stat.UserCount
json(response, change)
}

Expand All @@ -273,36 +297,34 @@ func putFeedback(request *restful.Request, response *restful.Response) {
return
}
var err error
status := Change{}
status.FeedbackBefore, err = db.CountFeedback()
if err != nil {
internalServerError(response, err)
return
}
status.FeedbackAfter = status.FeedbackBefore
status.ItemsBefore, err = db.CountItems()
change := Change{}
// Get status before change
stat, err := status()
if err != nil {
internalServerError(response, err)
return
}
change.FeedbackBefore = stat.FeedbackCount
change.ItemsBefore = stat.ItemCount
change.UsersBefore = stat.UserCount
// Insert feedback
for _, feedback := range *ratings {
err = db.InsertFeedback(feedback.UserId, feedback.ItemId, feedback.Feedback)
if err != nil {
internalServerError(response, err)
return
}
}
status.FeedbackAfter, err = db.CountFeedback()
// Get status after change
stat, err = status()
if err != nil {
internalServerError(response, err)
return
}
status.ItemsAfter, err = db.CountItems()
if err != nil {
internalServerError(response, err)
return
}
json(response, status)
change.FeedbackAfter = stat.FeedbackCount
change.ItemsAfter = stat.ItemCount
change.UsersAfter = stat.UserCount
json(response, change)
}

func badRequest(response *restful.Response, err error) {
Expand Down
5 changes: 4 additions & 1 deletion cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ func watch(config engine.TomlConfig, metaData toml.MetaData) {
} else if err != nil {
log.Fatal(err)
}
log.Printf("current number of feedback: %v, commit: %v\n", count, lastCount)
// Compare
if count-lastCount > config.Recommend.UpdateThreshold {
log.Printf("current count (%v) - commit (%v) > threshold (%v), start to update recommends\n",
Expand All @@ -89,6 +88,10 @@ func watch(config engine.TomlConfig, metaData toml.MetaData) {
if err = db.SetMeta("commit", strconv.Itoa(count)); err != nil {
log.Fatal(err)
}
t := time.Now()
if err = db.SetMeta("commit_time", t.String()); err != nil {
log.Fatal(err)
}
log.Printf("recommends update-to-date, commit = %v", count)
}
// Sleep
Expand Down
13 changes: 12 additions & 1 deletion cmd/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ func init() {
commandTest.PersistentFlags().Bool(paramFlag.Name, false, paramFlag.Help)
}
}
// Runtime options
commandTest.PersistentFlags().BoolP("verbose", "v", true, "verbose")
commandTest.PersistentFlags().Int("fit-jobs", 1, "number of jobs for model fitting")
commandTest.PersistentFlags().IntP("cv-jobs", "j", 1, "number of jobs for cross validation")
}

var commandTest = &cobra.Command{
Expand Down Expand Up @@ -127,8 +131,15 @@ var commandTest = &cobra.Command{
if len(rankMetrics) > 0 {
evaluators = append(evaluators, core.NewRankEvaluator(n, rankMetrics...))
}
// Load runtime options
verbose, _ := cmd.PersistentFlags().GetBool("verbose")
fitJobs, _ := cmd.PersistentFlags().GetInt("fit-jobs")
cvJobs, _ := cmd.PersistentFlags().GetInt("cv-jobs")
options := &base.RuntimeOptions{verbose, fitJobs, cvJobs}
log.Printf("Runtime options: verbose = %v, fit_jobs = %v, cv_jobs = %v\n",
options.GetVerbose(), options.GetFitJobs(), options.GetCVJobs())
// Cross validation
out := core.CrossValidate(model, data, core.NewKFoldSplitter(5), 0, nil, evaluators...)
out := core.CrossValidate(model, data, core.NewKFoldSplitter(5), 0, options, evaluators...)
// Render table
header := make([]string, k+2)
header[k+1] = "Mean"
Expand Down
2 changes: 1 addition & 1 deletion core/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func CrossValidate(model ModelInterface, dataSet DataSetInterface, splitter Spli
scores := make([][]float64, length)
costs := make([][]float64, length)
params := model.GetParams()
base.Parallel(length, options.GetJobs(), func(begin, end int) {
base.Parallel(length, options.GetCVJobs(), func(begin, end int) {
cp := reflect.New(reflect.TypeOf(model).Elem()).Interface().(ModelInterface)
Copy(cp, model)
cp.SetParams(params)
Expand Down
10 changes: 0 additions & 10 deletions docs/develop/index.rst

This file was deleted.

59 changes: 0 additions & 59 deletions docs/develop/model_selection.rst

This file was deleted.

0 comments on commit 66bb038

Please sign in to comment.