Skip to content

Commit

Permalink
Merge pull request #33 from davewang/master
Browse files Browse the repository at this point in the history
add option once to recommend once to user
  • Loading branch information
zhenghaoz committed Feb 25, 2020
2 parents 52ea080 + 6dc621d commit 07ca0d4
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 16 deletions.
64 changes: 57 additions & 7 deletions cmd/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (
"net/http"
"strconv"
)

func serve(config engine.ServerConfig) {
var engineConfig engine.TomlConfig
func serve(config engine.TomlConfig) {
engineConfig = config
// Create a web service
ws := new(restful.WebService)
ws.Consumes(restful.MIME_JSON).Produces(restful.MIME_JSON)
Expand Down Expand Up @@ -68,8 +69,8 @@ func serve(config engine.ServerConfig) {
ws.Route(ws.GET("/status").To(getStatus))
// Start web service
restful.DefaultContainer.Add(ws)
log.Printf("start a server at %v\n", fmt.Sprintf("%s:%d", config.Host, config.Port))
log.Fatal(http.ListenAndServe(fmt.Sprintf("%s:%d", config.Host, config.Port), nil))
log.Printf("start a server at %v\n", fmt.Sprintf("%s:%d", config.Server.Host, config.Server.Port))
log.Fatal(http.ListenAndServe(fmt.Sprintf("%s:%d", config.Server.Host, config.Server.Port), nil))
}

// Status contains information about engine.
Expand Down Expand Up @@ -272,9 +273,58 @@ func getRecommends(request *restful.Request, response *restful.Response) {
internalServerError(response, err)
return
}
// Send result
items = engine.Ranking(items, number, p, t, c)
json(response, items)
if engineConfig.Recommend.Once {
// Get read recommended items
reads,err := db.GetIdentMap(engine.BucketReads, userId)
if reads == nil {
reads = make(map[string]bool)
}
var subItems []engine.RecommendedItem
change := false
notRecommended := false
if err != nil {
change = false
} else {
for i := range items {
exist := reads[items[i].ItemId]
if !exist {
subItems = append(subItems,items[i])
change = true
}
}
if !change{
notRecommended = true
}
}
if notRecommended {
var empty []engine.RecommendedItem
json(response, empty)
}else if change {
subItems = engine.Ranking(subItems, number, p, t, c)
for i := range subItems {
reads[subItems[i].ItemId] = true
}
if err := db.PutIdentMap(engine.BucketReads, userId, reads); err != nil {
badRequest(response, err)
}
// Send result
json(response, subItems)
} else {
// Send result
items = engine.Ranking(items, number, p, t, c)
for i := range items {
reads[items[i].ItemId] = true
}
if err := db.PutIdentMap(engine.BucketReads, userId, reads); err != nil {
badRequest(response, err)
}
json(response, items)
}
}else{
items = engine.Ranking(items, number, p, t, c)
json(response, items)
}

}

// Change contains information of changes after insert.
Expand Down
2 changes: 1 addition & 1 deletion cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ var commandServe = &cobra.Command{
log.Printf("connect to database: %v", conf.Database.File)
// Start model daemon
go watch(conf, metaData)
serve(conf.Server)
serve(conf)
},
}

Expand Down
1 change: 1 addition & 0 deletions engine/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type RecommendConfig struct {
UpdateThreshold int `toml:"update_threshold"`
CheckPeriod int `toml:"check_period"`
FitJobs int `toml:"fit_jobs"`
Once bool `toml:"once"`
}

// ParamsConfig is the configuration for hyper-parameters of the recommendation model.
Expand Down
36 changes: 34 additions & 2 deletions engine/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
bktFeedback = "feedback" // Bucket name for feedback
BucketNeighbors = "neighbors" // Bucket name for neighbors
BucketRecommends = "recommends" // Bucket name for recommendations
BucketReads = "reads" // Bucket name for reads
bktUserFeedback = "user_feedback" // Bucket name for user feedback
)

Expand Down Expand Up @@ -67,7 +68,7 @@ func Open(path string) (*DB, error) {
}
// Create buckets
err = db.db.Update(func(tx *bolt.Tx) error {
bucketNames := []string{bktGlobal, bktItems, bktFeedback, BucketRecommends, BucketNeighbors, bktUserFeedback}
bucketNames := []string{bktGlobal, bktItems, bktFeedback, BucketRecommends,BucketReads, BucketNeighbors, bktUserFeedback}
for _, name := range bucketNames {
if _, err = tx.CreateBucketIfNotExists([]byte(name)); err != nil {
return err
Expand Down Expand Up @@ -480,7 +481,38 @@ func (db *DB) GetRandom(n int) ([]RecommendedItem, error) {
}
return items, nil
}

// Set item table for a user.
func (db *DB) PutIdentMap(bucketName string, id string, items map[string]bool) error {
return db.db.Update(func(tx *bolt.Tx) error {
// Get bucket
bucket := tx.Bucket([]byte(bucketName))
// Marshal data into bytes
buf, err := json.Marshal(items)
if err != nil {
return err
}
// Persist bytes to bucket
return bucket.Put([]byte(id), buf)
})
}
// Get item table for a user.
func (db *DB) GetIdentMap(bucketName string, id string) (map[string]bool, error) {
var items map[string]bool
err := db.db.View(func(tx *bolt.Tx) error {
// Get bucket
bucket := tx.Bucket([]byte(bucketName))
// Unmarshal data into bytes
buf := bucket.Get([]byte(id))
if buf == nil {
return fmt.Errorf("%v not found", id)
}
return json.Unmarshal(buf, &items)
})
if err != nil {
return nil, err
}
return items, nil
}
// SetRecommends sets recommendations for a user.
func (db *DB) PutIdentList(bucketName string, id string, items []RecommendedItem) error {
return db.db.Update(func(tx *bolt.Tx) error {
Expand Down
25 changes: 21 additions & 4 deletions engine/offline.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ func UpdateNeighbors(name string, cacheSize int, dataSet core.DataSetInterface,
}
return nil
}

// UpdateRecommends updates personalized recommendations for the database.
func UpdateRecommends(name string, params base.Params, cacheSize int, fitJobs int, dataSet core.DataSetInterface, db *DB) error {
func UpdateRecommends(name string, params base.Params, cacheSize int, fitJobs int,once bool, dataSet core.DataSetInterface, db *DB) error {
// Create model
log.Printf("create model %v with params = %v\n", name, params)
model := LoadModel(name, params)
Expand All @@ -98,7 +97,25 @@ func UpdateRecommends(name string, params base.Params, cacheSize int, fitJobs in
for userIndex := 0; userIndex < dataSet.UserCount(); userIndex++ {
userId := dataSet.UserIndexer().ToID(userIndex)
exclude := dataSet.UserByIndex(userIndex)
recommendItems, ratings := core.Top(items, userId, cacheSize, exclude, model)
// get read items
subItems := make(map[string]bool)
if once {
reads,err := db.GetIdentMap(BucketReads, userId)
if err != nil {
// reads is empty
subItems = items
}else{
for itemID := range items {
exist := reads[itemID]
if !exist {
subItems[itemID] = true
}
}
}
} else {
subItems = items
}
recommendItems, ratings := core.Top(subItems, userId, cacheSize, exclude, model)
recommends := make([]RecommendedItem, len(recommendItems))
items, err := db.GetItemsByID(recommendItems)
if err != nil {
Expand Down Expand Up @@ -137,7 +154,7 @@ func Update(config TomlConfig, metaData toml.MetaData, db *DB) error {
}
// Generate recommends
params := config.Params.ToParams(metaData)
if err = UpdateRecommends(config.Recommend.Model, params, config.Recommend.CacheSize, config.Recommend.FitJobs,
if err = UpdateRecommends(config.Recommend.Model, params, config.Recommend.CacheSize, config.Recommend.FitJobs, config.Recommend.Once,
dataSet, db); err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions engine/offline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func TestUpdateRecommends(t *testing.T) {
base.InitMean: 0,
base.InitStdDev: 0.001,
}
if err = UpdateRecommends("bpr", params, 10, runtime.NumCPU(), trainSet, db); err != nil {
if err = UpdateRecommends("bpr", params, 10, runtime.NumCPU(),false, trainSet, db); err != nil {
t.Fatal(err)
}
// Check result
Expand Down Expand Up @@ -230,7 +230,7 @@ func TestUpdateRecommendsInvalidModel(t *testing.T) {
if err = db.InsertItems(itemId, nil); err != nil {
t.Fatal(err)
}
if err = UpdateRecommends("invalid-model", nil, 10, runtime.NumCPU(), dataSet, db); err == nil {
if err = UpdateRecommends("invalid-model", nil, 10, runtime.NumCPU(), false,dataSet, db); err == nil {
t.Fatal("function should return an error")
}
// Close database
Expand Down

0 comments on commit 07ca0d4

Please sign in to comment.